Hannes Kuchelmeister committed
Commit d2e7940
1 Parent(s): 554c212

remove old template and use https://github.com/ashleve/lightning-hydra-template instead

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. model/.dockerignore +0 -4
  2. model/Dockerfile +0 -13
  3. model/LICENSE +0 -21
  4. model/README.md +0 -378
  5. model/base/__init__.py +0 -3
  6. model/base/base_data_loader.py +0 -61
  7. model/base/base_model.py +0 -25
  8. model/base/base_trainer.py +0 -151
  9. model/config.json +0 -50
  10. model/data_loader/data_loaders.py +0 -16
  11. model/docker-compose.yml +0 -12
  12. model/logger/__init__.py +0 -2
  13. model/logger/logger.py +0 -22
  14. model/logger/logger_config.json +0 -32
  15. model/logger/visualization.py +0 -73
  16. model/model/loss.py +0 -5
  17. model/model/metric.py +0 -20
  18. model/model/model.py +0 -22
  19. model/new_project.py +0 -18
  20. model/parse_config.py +0 -157
  21. model/requirements.txt +0 -6
  22. model/test.py +0 -81
  23. model/train.py +0 -73
  24. model/trainer/__init__.py +0 -1
  25. model/trainer/trainer.py +0 -110
  26. model/utils/__init__.py +0 -1
  27. model/utils/util.py +0 -67
  28. models/.env.example +7 -0
  29. models/.gitignore +148 -0
  30. models/.pre-commit-config.yaml +44 -0
  31. models/README.md +1445 -0
  32. models/configs/callbacks/default.yaml +24 -0
  33. models/configs/callbacks/none.yaml +0 -0
  34. models/configs/datamodule/mnist.yaml +7 -0
  35. models/configs/debug/default.yaml +28 -0
  36. models/configs/debug/limit_batches.yaml +12 -0
  37. models/configs/debug/overfit.yaml +10 -0
  38. models/configs/debug/profiler.yaml +12 -0
  39. models/configs/debug/step.yaml +9 -0
  40. models/configs/debug/test_only.yaml +9 -0
  41. models/configs/experiment/example.yaml +37 -0
  42. models/configs/hparams_search/mnist_optuna.yaml +60 -0
  43. models/configs/local/.gitkeep +0 -0
  44. models/configs/log_dir/debug.yaml +8 -0
  45. models/configs/log_dir/default.yaml +8 -0
  46. models/configs/log_dir/evaluation.yaml +8 -0
  47. models/configs/logger/comet.yaml +7 -0
  48. models/configs/logger/csv.yaml +7 -0
  49. models/configs/logger/many_loggers.yaml +9 -0
  50. models/configs/logger/mlflow.yaml +10 -0
model/.dockerignore DELETED
@@ -1,4 +0,0 @@
- .env
- in/
- out/
- new_proect.py
model/Dockerfile DELETED
@@ -1,13 +0,0 @@
- FROM python:3.7
-
- WORKDIR /usr/src/app
-
- RUN apt-get update
- RUN apt-get install libgl1 -y
-
- COPY requirements.txt ./
- RUN pip install --no-cache-dir -r requirements.txt
-
- COPY . .
-
- CMD sh -c "python train.py --config config.json"
model/LICENSE DELETED
@@ -1,21 +0,0 @@
- MIT License
-
- Copyright (c) 2018 Victor Huang
-
- Permission is hereby granted, free of charge, to any person obtaining a copy
- of this software and associated documentation files (the "Software"), to deal
- in the Software without restriction, including without limitation the rights
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- copies of the Software, and to permit persons to whom the Software is
- furnished to do so, subject to the following conditions:
-
- The above copyright notice and this permission notice shall be included in all
- copies or substantial portions of the Software.
-
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- SOFTWARE.
model/README.md DELETED
@@ -1,378 +0,0 @@
- # PyTorch Template Project
- PyTorch deep learning project made easy.
-
- <!-- @import "[TOC]" {cmd="toc" depthFrom=1 depthTo=6 orderedList=false} -->
-
- <!-- code_chunk_output -->
-
- * [PyTorch Template Project](#pytorch-template-project)
- * [Requirements](#requirements)
- * [Features](#features)
- * [Folder Structure](#folder-structure)
- * [Usage](#usage)
- * [Config file format](#config-file-format)
- * [Using config files](#using-config-files)
- * [Resuming from checkpoints](#resuming-from-checkpoints)
- * [Using Multiple GPU](#using-multiple-gpu)
- * [Customization](#customization)
- * [Custom CLI options](#custom-cli-options)
- * [Data Loader](#data-loader)
- * [Trainer](#trainer)
- * [Model](#model)
- * [Loss](#loss)
- * [metrics](#metrics)
- * [Additional logging](#additional-logging)
- * [Validation data](#validation-data)
- * [Checkpoints](#checkpoints)
- * [Tensorboard Visualization](#tensorboard-visualization)
- * [Contribution](#contribution)
- * [TODOs](#todos)
- * [License](#license)
- * [Acknowledgements](#acknowledgements)
-
- <!-- /code_chunk_output -->
-
- ## Requirements
- * Python >= 3.5 (3.6 recommended)
- * PyTorch >= 0.4 (1.2 recommended)
- * tqdm (Optional for `test.py`)
- * tensorboard >= 1.14 (see [Tensorboard Visualization](#tensorboard-visualization))
-
- ## Features
- * Clear folder structure which is suitable for many deep learning projects.
- * `.json` config file support for convenient parameter tuning.
- * Customizable command line options for more convenient parameter tuning.
- * Checkpoint saving and resuming.
- * Abstract base classes for faster development:
- * `BaseTrainer` handles checkpoint saving/resuming, training process logging, and more.
- * `BaseDataLoader` handles batch generation, data shuffling, and validation data splitting.
- * `BaseModel` provides basic model summary.
-
- ## Folder Structure
- ```
- pytorch-template/
-
- ├── train.py - main script to start training
- ├── test.py - evaluation of trained model
-
- ├── config.json - holds configuration for training
- ├── parse_config.py - class to handle config file and cli options
-
- ├── new_project.py - initialize new project with template files
-
- ├── base/ - abstract base classes
- │ ├── base_data_loader.py
- │ ├── base_model.py
- │ └── base_trainer.py
-
- ├── data_loader/ - anything about data loading goes here
- │ └── data_loaders.py
-
- ├── data/ - default directory for storing input data
-
- ├── model/ - models, losses, and metrics
- │ ├── model.py
- │ ├── metric.py
- │ └── loss.py
-
- ├── saved/
- │ ├── models/ - trained models are saved here
- │ └── log/ - default logdir for tensorboard and logging output
-
- ├── trainer/ - trainers
- │ └── trainer.py
-
- ├── logger/ - module for tensorboard visualization and logging
- │ ├── visualization.py
- │ ├── logger.py
- │ └── logger_config.json
-
- └── utils/ - small utility functions
- ├── util.py
- └── ...
- ```
-
- ## Usage
- The code in this repo is an MNIST example of the template.
- Try `python train.py -c config.json` to run code.
-
- ### Config file format
- Config files are in `.json` format:
- ```javascript
- {
- "name": "Mnist_LeNet", // training session name
- "n_gpu": 1, // number of GPUs to use for training.
-
- "arch": {
- "type": "MnistModel", // name of model architecture to train
- "args": {
-
- }
- },
- "data_loader": {
- "type": "MnistDataLoader", // selecting data loader
- "args":{
- "data_dir": "data/", // dataset path
- "batch_size": 64, // batch size
- "shuffle": true, // shuffle training data before splitting
- "validation_split": 0.1 // size of validation dataset. float(portion) or int(number of samples)
- "num_workers": 2, // number of cpu processes to be used for data loading
- }
- },
- "optimizer": {
- "type": "Adam",
- "args":{
- "lr": 0.001, // learning rate
- "weight_decay": 0, // (optional) weight decay
- "amsgrad": true
- }
- },
- "loss": "nll_loss", // loss
- "metrics": [
- "accuracy", "top_k_acc" // list of metrics to evaluate
- ],
- "lr_scheduler": {
- "type": "StepLR", // learning rate scheduler
- "args":{
- "step_size": 50,
- "gamma": 0.1
- }
- },
- "trainer": {
- "epochs": 100, // number of training epochs
- "save_dir": "saved/", // checkpoints are saved in save_dir/models/name
- "save_freq": 1, // save checkpoints every save_freq epochs
- "verbosity": 2, // 0: quiet, 1: per epoch, 2: full
-
- "monitor": "min val_loss" // mode and metric for model performance monitoring. set 'off' to disable.
- "early_stop": 10 // number of epochs to wait before early stop. set 0 to disable.
-
- "tensorboard": true, // enable tensorboard visualization
- }
- }
- ```
-
- Add addional configurations if you need.
-
- ### Using config files
- Modify the configurations in `.json` config files, then run:
-
- ```
- python train.py --config config.json
- ```
-
- ### Resuming from checkpoints
- You can resume from a previously saved checkpoint by:
-
- ```
- python train.py --resume path/to/checkpoint
- ```
-
- ### Using Multiple GPU
- You can enable multi-GPU training by setting `n_gpu` argument of the config file to larger number.
- If configured to use smaller number of gpu than available, first n devices will be used by default.
- Specify indices of available GPUs by cuda environmental variable.
- ```
- python train.py --device 2,3 -c config.json
- ```
- This is equivalent to
- ```
- CUDA_VISIBLE_DEVICES=2,3 python train.py -c config.py
- ```
-
- ## Customization
-
- ### Project initialization
- Use the `new_project.py` script to make your new project directory with template files.
- `python new_project.py ../NewProject` then a new project folder named 'NewProject' will be made.
- This script will filter out unneccessary files like cache, git files or readme file.
-
- ### Custom CLI options
-
- Changing values of config file is a clean, safe and easy way of tuning hyperparameters. However, sometimes
- it is better to have command line options if some values need to be changed too often or quickly.
-
- This template uses the configurations stored in the json file by default, but by registering custom options as follows
- you can change some of them using CLI flags.
-
- ```python
- # simple class-like object having 3 attributes, `flags`, `type`, `target`.
- CustomArgs = collections.namedtuple('CustomArgs', 'flags type target')
- options = [
- CustomArgs(['--lr', '--learning_rate'], type=float, target=('optimizer', 'args', 'lr')),
- CustomArgs(['--bs', '--batch_size'], type=int, target=('data_loader', 'args', 'batch_size'))
- # options added here can be modified by command line flags.
- ]
- ```
- `target` argument should be sequence of keys, which are used to access that option in the config dict. In this example, `target`
- for the learning rate option is `('optimizer', 'args', 'lr')` because `config['optimizer']['args']['lr']` points to the learning rate.
- `python train.py -c config.json --bs 256` runs training with options given in `config.json` except for the `batch size`
- which is increased to 256 by command line options.
-
-
- ### Data Loader
- * **Writing your own data loader**
-
- 1. **Inherit ```BaseDataLoader```**
-
- `BaseDataLoader` is a subclass of `torch.utils.data.DataLoader`, you can use either of them.
-
- `BaseDataLoader` handles:
- * Generating next batch
- * Data shuffling
- * Generating validation data loader by calling
- `BaseDataLoader.split_validation()`
-
- * **DataLoader Usage**
-
- `BaseDataLoader` is an iterator, to iterate through batches:
- ```python
- for batch_idx, (x_batch, y_batch) in data_loader:
- pass
- ```
- * **Example**
-
- Please refer to `data_loader/data_loaders.py` for an MNIST data loading example.
-
- ### Trainer
- * **Writing your own trainer**
-
- 1. **Inherit ```BaseTrainer```**
-
- `BaseTrainer` handles:
- * Training process logging
- * Checkpoint saving
- * Checkpoint resuming
- * Reconfigurable performance monitoring for saving current best model, and early stop training.
- * If config `monitor` is set to `max val_accuracy`, which means then the trainer will save a checkpoint `model_best.pth` when `validation accuracy` of epoch replaces current `maximum`.
- * If config `early_stop` is set, training will be automatically terminated when model performance does not improve for given number of epochs. This feature can be turned off by passing 0 to the `early_stop` option, or just deleting the line of config.
-
- 2. **Implementing abstract methods**
-
- You need to implement `_train_epoch()` for your training process, if you need validation then you can implement `_valid_epoch()` as in `trainer/trainer.py`
-
- * **Example**
-
- Please refer to `trainer/trainer.py` for MNIST training.
-
- * **Iteration-based training**
-
- `Trainer.__init__` takes an optional argument, `len_epoch` which controls number of batches(steps) in each epoch.
-
- ### Model
- * **Writing your own model**
-
- 1. **Inherit `BaseModel`**
-
- `BaseModel` handles:
- * Inherited from `torch.nn.Module`
- * `__str__`: Modify native `print` function to prints the number of trainable parameters.
-
- 2. **Implementing abstract methods**
-
- Implement the foward pass method `forward()`
-
- * **Example**
-
- Please refer to `model/model.py` for a LeNet example.
-
- ### Loss
- Custom loss functions can be implemented in 'model/loss.py'. Use them by changing the name given in "loss" in config file, to corresponding name.
-
- ### Metrics
- Metric functions are located in 'model/metric.py'.
-
- You can monitor multiple metrics by providing a list in the configuration file, e.g.:
- ```json
- "metrics": ["accuracy", "top_k_acc"],
- ```
-
- ### Additional logging
- If you have additional information to be logged, in `_train_epoch()` of your trainer class, merge them with `log` as shown below before returning:
-
- ```python
- additional_log = {"gradient_norm": g, "sensitivity": s}
- log.update(additional_log)
- return log
- ```
-
- ### Testing
- You can test trained model by running `test.py` passing path to the trained checkpoint by `--resume` argument.
-
- ### Validation data
- To split validation data from a data loader, call `BaseDataLoader.split_validation()`, then it will return a data loader for validation of size specified in your config file.
- The `validation_split` can be a ratio of validation set per total data(0.0 <= float < 1.0), or the number of samples (0 <= int < `n_total_samples`).
-
- **Note**: the `split_validation()` method will modify the original data loader
- **Note**: `split_validation()` will return `None` if `"validation_split"` is set to `0`
-
- ### Checkpoints
- You can specify the name of the training session in config files:
- ```json
- "name": "MNIST_LeNet",
- ```
-
- The checkpoints will be saved in `save_dir/name/timestamp/checkpoint_epoch_n`, with timestamp in mmdd_HHMMSS format.
-
- A copy of config file will be saved in the same folder.
-
- **Note**: checkpoints contain:
- ```python
- {
- 'arch': arch,
- 'epoch': epoch,
- 'state_dict': self.model.state_dict(),
- 'optimizer': self.optimizer.state_dict(),
- 'monitor_best': self.mnt_best,
- 'config': self.config
- }
- ```
-
- ### Tensorboard Visualization
- This template supports Tensorboard visualization by using either `torch.utils.tensorboard` or [TensorboardX](https://github.com/lanpa/tensorboardX).
-
- 1. **Install**
-
- If you are using pytorch 1.1 or higher, install tensorboard by 'pip install tensorboard>=1.14.0'.
-
- Otherwise, you should install tensorboardx. Follow installation guide in [TensorboardX](https://github.com/lanpa/tensorboardX).
-
- 2. **Run training**
-
- Make sure that `tensorboard` option in the config file is turned on.
-
- ```
- "tensorboard" : true
- ```
-
- 3. **Open Tensorboard server**
-
- Type `tensorboard --logdir saved/log/` at the project root, then server will open at `http://localhost:6006`
-
- By default, values of loss and metrics specified in config file, input images, and histogram of model parameters will be logged.
- If you need more visualizations, use `add_scalar('tag', data)`, `add_image('tag', image)`, etc in the `trainer._train_epoch` method.
- `add_something()` methods in this template are basically wrappers for those of `tensorboardX.SummaryWriter` and `torch.utils.tensorboard.SummaryWriter` modules.
-
- **Note**: You don't have to specify current steps, since `WriterTensorboard` class defined at `logger/visualization.py` will track current steps.
-
- ## Contribution
- Feel free to contribute any kind of function or enhancement, here the coding style follows PEP8
-
- Code should pass the [Flake8](http://flake8.pycqa.org/en/latest/) check before committing.
-
- ## TODOs
-
- - [ ] Multiple optimizers
- - [ ] Support more tensorboard functions
- - [x] Using fixed random seed
- - [x] Support pytorch native tensorboard
- - [x] `tensorboardX` logger support
- - [x] Configurable logging layout, checkpoint naming
- - [x] Iteration-based training (instead of epoch-based)
- - [x] Adding command line option for fine-tuning
-
- ## License
- This project is licensed under the MIT License. See LICENSE for more details
-
- ## Acknowledgements
- This project is inspired by the project [Tensorflow-Project-Template](https://github.com/MrGemy95/Tensorflow-Project-Template) by [Mahmoud Gemy](https://github.com/MrGemy95)
model/base/__init__.py DELETED
@@ -1,3 +0,0 @@
- from .base_data_loader import *
- from .base_model import *
- from .base_trainer import *
model/base/base_data_loader.py DELETED
@@ -1,61 +0,0 @@
- import numpy as np
- from torch.utils.data import DataLoader
- from torch.utils.data.dataloader import default_collate
- from torch.utils.data.sampler import SubsetRandomSampler
-
-
- class BaseDataLoader(DataLoader):
- """
- Base class for all data loaders
- """
- def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=default_collate):
- self.validation_split = validation_split
- self.shuffle = shuffle
-
- self.batch_idx = 0
- self.n_samples = len(dataset)
-
- self.sampler, self.valid_sampler = self._split_sampler(self.validation_split)
-
- self.init_kwargs = {
- 'dataset': dataset,
- 'batch_size': batch_size,
- 'shuffle': self.shuffle,
- 'collate_fn': collate_fn,
- 'num_workers': num_workers
- }
- super().__init__(sampler=self.sampler, **self.init_kwargs)
-
- def _split_sampler(self, split):
- if split == 0.0:
- return None, None
-
- idx_full = np.arange(self.n_samples)
-
- np.random.seed(0)
- np.random.shuffle(idx_full)
-
- if isinstance(split, int):
- assert split > 0
- assert split < self.n_samples, "validation set size is configured to be larger than entire dataset."
- len_valid = split
- else:
- len_valid = int(self.n_samples * split)
-
- valid_idx = idx_full[0:len_valid]
- train_idx = np.delete(idx_full, np.arange(0, len_valid))
-
- train_sampler = SubsetRandomSampler(train_idx)
- valid_sampler = SubsetRandomSampler(valid_idx)
-
- # turn off shuffle option which is mutually exclusive with sampler
- self.shuffle = False
- self.n_samples = len(train_idx)
-
- return train_sampler, valid_sampler
-
- def split_validation(self):
- if self.valid_sampler is None:
- return None
- else:
- return DataLoader(sampler=self.valid_sampler, **self.init_kwargs)
model/base/base_model.py DELETED
@@ -1,25 +0,0 @@
- import torch.nn as nn
- import numpy as np
- from abc import abstractmethod
-
-
- class BaseModel(nn.Module):
- """
- Base class for all models
- """
- @abstractmethod
- def forward(self, *inputs):
- """
- Forward pass logic
-
- :return: Model output
- """
- raise NotImplementedError
-
- def __str__(self):
- """
- Model prints with number of trainable parameters
- """
- model_parameters = filter(lambda p: p.requires_grad, self.parameters())
- params = sum([np.prod(p.size()) for p in model_parameters])
- return super().__str__() + '\nTrainable parameters: {}'.format(params)
model/base/base_trainer.py DELETED
@@ -1,151 +0,0 @@
- import torch
- from abc import abstractmethod
- from numpy import inf
- from logger import TensorboardWriter
-
-
- class BaseTrainer:
- """
- Base class for all trainers
- """
- def __init__(self, model, criterion, metric_ftns, optimizer, config):
- self.config = config
- self.logger = config.get_logger('trainer', config['trainer']['verbosity'])
-
- self.model = model
- self.criterion = criterion
- self.metric_ftns = metric_ftns
- self.optimizer = optimizer
-
- cfg_trainer = config['trainer']
- self.epochs = cfg_trainer['epochs']
- self.save_period = cfg_trainer['save_period']
- self.monitor = cfg_trainer.get('monitor', 'off')
-
- # configuration to monitor model performance and save best
- if self.monitor == 'off':
- self.mnt_mode = 'off'
- self.mnt_best = 0
- else:
- self.mnt_mode, self.mnt_metric = self.monitor.split()
- assert self.mnt_mode in ['min', 'max']
-
- self.mnt_best = inf if self.mnt_mode == 'min' else -inf
- self.early_stop = cfg_trainer.get('early_stop', inf)
- if self.early_stop <= 0:
- self.early_stop = inf
-
- self.start_epoch = 1
-
- self.checkpoint_dir = config.save_dir
-
- # setup visualization writer instance
- self.writer = TensorboardWriter(config.log_dir, self.logger, cfg_trainer['tensorboard'])
-
- if config.resume is not None:
- self._resume_checkpoint(config.resume)
-
- @abstractmethod
- def _train_epoch(self, epoch):
- """
- Training logic for an epoch
-
- :param epoch: Current epoch number
- """
- raise NotImplementedError
-
- def train(self):
- """
- Full training logic
- """
- not_improved_count = 0
- for epoch in range(self.start_epoch, self.epochs + 1):
- result = self._train_epoch(epoch)
-
- # save logged informations into log dict
- log = {'epoch': epoch}
- log.update(result)
-
- # print logged informations to the screen
- for key, value in log.items():
- self.logger.info(' {:15s}: {}'.format(str(key), value))
-
- # evaluate model performance according to configured metric, save best checkpoint as model_best
- best = False
- if self.mnt_mode != 'off':
- try:
- # check whether model performance improved or not, according to specified metric(mnt_metric)
- improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \
- (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best)
- except KeyError:
- self.logger.warning("Warning: Metric '{}' is not found. "
- "Model performance monitoring is disabled.".format(self.mnt_metric))
- self.mnt_mode = 'off'
- improved = False
-
- if improved:
- self.mnt_best = log[self.mnt_metric]
- not_improved_count = 0
- best = True
- else:
- not_improved_count += 1
-
- if not_improved_count > self.early_stop:
- self.logger.info("Validation performance didn\'t improve for {} epochs. "
- "Training stops.".format(self.early_stop))
- break
-
- if epoch % self.save_period == 0:
- self._save_checkpoint(epoch, save_best=best)
-
- def _save_checkpoint(self, epoch, save_best=False):
- """
- Saving checkpoints
-
- :param epoch: current epoch number
- :param log: logging information of the epoch
- :param save_best: if True, rename the saved checkpoint to 'model_best.pth'
- """
- arch = type(self.model).__name__
- state = {
- 'arch': arch,
- 'epoch': epoch,
- 'state_dict': self.model.state_dict(),
- 'optimizer': self.optimizer.state_dict(),
- 'monitor_best': self.mnt_best,
- 'config': self.config
- }
- filename = str(self.checkpoint_dir / 'checkpoint-epoch{}.pth'.format(epoch))
- torch.save(state, filename)
- self.logger.info("Saving checkpoint: {} ...".format(filename))
- if save_best:
- best_path = str(self.checkpoint_dir / 'model_best.pth')
- torch.save(state, best_path)
- self.logger.info("Saving current best: model_best.pth ...")
-
- def _resume_checkpoint(self, resume_path):
- """
- Resume from saved checkpoints
-
- :param resume_path: Checkpoint path to be resumed
- """
- resume_path = str(resume_path)
- self.logger.info("Loading checkpoint: {} ...".format(resume_path))
- checkpoint = torch.load(resume_path)
- self.start_epoch = checkpoint['epoch'] + 1
- self.mnt_best = checkpoint['monitor_best']
-
- # load architecture params from checkpoint.
- if checkpoint['config']['arch'] != self.config['arch']:
- self.logger.warning("Warning: Architecture configuration given in config file is different from that of "
- "checkpoint. This may yield an exception while state_dict is being loaded.")
- self.model.load_state_dict(checkpoint['state_dict'])
-
- # load optimizer state from checkpoint only when optimizer type is not changed.
- if checkpoint['config']['optimizer']['type'] != self.config['optimizer']['type']:
- self.logger.warning("Warning: Optimizer type given in config file is different from that of checkpoint. "
- "Optimizer parameters not being resumed.")
- else:
- self.optimizer.load_state_dict(checkpoint['optimizer'])
-
- self.logger.info("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch))
model/config.json DELETED
@@ -1,50 +0,0 @@
- {
- "name": "Mnist_LeNet",
- "n_gpu": 1,
-
- "arch": {
- "type": "MnistModel",
- "args": {}
- },
- "data_loader": {
- "type": "MnistDataLoader",
- "args":{
- "data_dir": "data/",
- "batch_size": 128,
- "shuffle": true,
- "validation_split": 0.1,
- "num_workers": 2
- }
- },
- "optimizer": {
- "type": "Adam",
- "args":{
- "lr": 0.001,
- "weight_decay": 0,
- "amsgrad": true
- }
- },
- "loss": "nll_loss",
- "metrics": [
- "accuracy", "top_k_acc"
- ],
- "lr_scheduler": {
- "type": "StepLR",
- "args": {
- "step_size": 50,
- "gamma": 0.1
- }
- },
- "trainer": {
- "epochs": 100,
-
- "save_dir": "saved/",
- "save_period": 1,
- "verbosity": 2,
-
- "monitor": "min val_loss",
- "early_stop": 10,
-
- "tensorboard": true
- }
- }
model/data_loader/data_loaders.py DELETED
@@ -1,16 +0,0 @@
- from torchvision import datasets, transforms
- from base import BaseDataLoader
-
-
- class MnistDataLoader(BaseDataLoader):
- """
- MNIST data loading demo using BaseDataLoader
- """
- def __init__(self, data_dir, batch_size, shuffle=True, validation_split=0.0, num_workers=1, training=True):
- trsfm = transforms.Compose([
- transforms.ToTensor(),
- transforms.Normalize((0.1307,), (0.3081,))
- ])
- self.data_dir = data_dir
- self.dataset = datasets.MNIST(self.data_dir, train=training, download=True, transform=trsfm)
- super().__init__(self.dataset, batch_size, shuffle, validation_split, num_workers)
model/docker-compose.yml DELETED
@@ -1,12 +0,0 @@
- version: "3" # optional since v1.27.0
- services:
- model:
- build: .
- ports:
- - 6006:6006
- volumes:
- - ./saved:/usr/src/app/saved:z
- # - ./out/:/usr/src/app/out:z
- # - ./in/:/usr/src/app/in:z
- #env_file:
- # - .env
model/logger/__init__.py DELETED
@@ -1,2 +0,0 @@
- from .logger import *
- from .visualization import *
model/logger/logger.py DELETED
@@ -1,22 +0,0 @@
- import logging
- import logging.config
- from pathlib import Path
- from utils import read_json
-
-
- def setup_logging(save_dir, log_config='logger/logger_config.json', default_level=logging.INFO):
- """
- Setup logging configuration
- """
- log_config = Path(log_config)
- if log_config.is_file():
- config = read_json(log_config)
- # modify logging paths based on run config
- for _, handler in config['handlers'].items():
- if 'filename' in handler:
- handler['filename'] = str(save_dir / handler['filename'])
-
- logging.config.dictConfig(config)
- else:
- print("Warning: logging configuration file is not found in {}.".format(log_config))
- logging.basicConfig(level=default_level)
model/logger/logger_config.json DELETED
@@ -1,32 +0,0 @@
-
- {
- "version": 1,
- "disable_existing_loggers": false,
- "formatters": {
- "simple": {"format": "%(message)s"},
- "datetime": {"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"}
- },
- "handlers": {
- "console": {
- "class": "logging.StreamHandler",
- "level": "DEBUG",
- "formatter": "simple",
- "stream": "ext://sys.stdout"
- },
- "info_file_handler": {
- "class": "logging.handlers.RotatingFileHandler",
- "level": "INFO",
- "formatter": "datetime",
- "filename": "info.log",
- "maxBytes": 10485760,
- "backupCount": 20, "encoding": "utf8"
- }
- },
- "root": {
- "level": "INFO",
- "handlers": [
- "console",
- "info_file_handler"
- ]
- }
- }
model/logger/visualization.py DELETED
@@ -1,73 +0,0 @@
- import importlib
- from datetime import datetime
-
-
- class TensorboardWriter():
- def __init__(self, log_dir, logger, enabled):
- self.writer = None
- self.selected_module = ""
-
- if enabled:
- log_dir = str(log_dir)
-
- # Retrieve vizualization writer.
- succeeded = False
- for module in ["torch.utils.tensorboard", "tensorboardX"]:
- try:
- self.writer = importlib.import_module(module).SummaryWriter(log_dir)
- succeeded = True
- break
- except ImportError:
- succeeded = False
- self.selected_module = module
-
- if not succeeded:
- message = "Warning: visualization (Tensorboard) is configured to use, but currently not installed on " \
- "this machine. Please install TensorboardX with 'pip install tensorboardx', upgrade PyTorch to " \
- "version >= 1.1 to use 'torch.utils.tensorboard' or turn off the option in the 'config.json' file."
- logger.warning(message)
-
- self.step = 0
- self.mode = ''
-
- self.tb_writer_ftns = {
- 'add_scalar', 'add_scalars', 'add_image', 'add_images', 'add_audio',
- 'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding'
- }
- self.tag_mode_exceptions = {'add_histogram', 'add_embedding'}
- self.timer = datetime.now()
-
- def set_step(self, step, mode='train'):
- self.mode = mode
- self.step = step
- if step == 0:
- self.timer = datetime.now()
- else:
- duration = datetime.now() - self.timer
- self.add_scalar('steps_per_sec', 1 / duration.total_seconds())
- self.timer = datetime.now()
-
- def __getattr__(self, name):
- """
- If visualization is configured to use:
- return add_data() methods of tensorboard with additional information (step, tag) added.
- Otherwise:
- return a blank function handle that does nothing
- """
- if name in self.tb_writer_ftns:
- add_data = getattr(self.writer, name, None)
-
- def wrapper(tag, data, *args, **kwargs):
- if add_data is not None:
- # add mode(train/valid) tag
- if name not in self.tag_mode_exceptions:
- tag = '{}/{}'.format(tag, self.mode)
- add_data(tag, data, self.step, *args, **kwargs)
- return wrapper
- else:
- # default action for returning methods defined in this class, set_step() for instance.
- try:
- attr = object.__getattr__(name)
- except AttributeError:
- raise AttributeError("type object '{}' has no attribute '{}'".format(self.selected_module, name))
- return attr
model/model/loss.py DELETED
@@ -1,5 +0,0 @@
- import torch.nn.functional as F
-
-
- def nll_loss(output, target):
- return F.nll_loss(output, target)
model/model/metric.py DELETED
@@ -1,20 +0,0 @@
- import torch
-
-
- def accuracy(output, target):
- with torch.no_grad():
- pred = torch.argmax(output, dim=1)
- assert pred.shape[0] == len(target)
- correct = 0
- correct += torch.sum(pred == target).item()
- return correct / len(target)
-
-
- def top_k_acc(output, target, k=3):
- with torch.no_grad():
- pred = torch.topk(output, k, dim=1)[1]
- assert pred.shape[0] == len(target)
- correct = 0
- for i in range(k):
- correct += torch.sum(pred[:, i] == target).item()
- return correct / len(target)
model/model/model.py DELETED
@@ -1,22 +0,0 @@
- import torch.nn as nn
- import torch.nn.functional as F
- from base import BaseModel
-
-
- class MnistModel(BaseModel):
- def __init__(self, num_classes=10):
- super().__init__()
- self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
- self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
- self.conv2_drop = nn.Dropout2d()
- self.fc1 = nn.Linear(320, 50)
- self.fc2 = nn.Linear(50, num_classes)
-
- def forward(self, x):
- x = F.relu(F.max_pool2d(self.conv1(x), 2))
- x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
- x = x.view(-1, 320)
- x = F.relu(self.fc1(x))
- x = F.dropout(x, training=self.training)
- x = self.fc2(x)
- return F.log_softmax(x, dim=1)
model/new_project.py DELETED
@@ -1,18 +0,0 @@
- import sys
- from pathlib import Path
- from shutil import copytree, ignore_patterns
-
-
- # This script initializes new pytorch project with the template files.
- # Run `python3 new_project.py ../MyNewProject` then new project named
- # MyNewProject will be made
- current_dir = Path()
- assert (current_dir / 'new_project.py').is_file(), 'Script should be executed in the pytorch-template directory'
- assert len(sys.argv) == 2, 'Specify a name for the new project. Example: python3 new_project.py MyNewProject'
-
- project_name = Path(sys.argv[1])
- target_dir = current_dir / project_name
-
- ignore = [".git", "data", "saved", "new_project.py", "LICENSE", ".flake8", "README.md", "__pycache__"]
- copytree(current_dir, target_dir, ignore=ignore_patterns(*ignore))
- print('New project initialized at', target_dir.absolute().resolve())
model/parse_config.py DELETED
@@ -1,157 +0,0 @@
- import os
- import logging
- from pathlib import Path
- from functools import reduce, partial
- from operator import getitem
- from datetime import datetime
- from logger import setup_logging
- from utils import read_json, write_json
-
-
- class ConfigParser:
- def __init__(self, config, resume=None, modification=None, run_id=None):
- """
- class to parse configuration json file. Handles hyperparameters for training, initializations of modules, checkpoint saving
- and logging module.
- :param config: Dict containing configurations, hyperparameters for training. contents of `config.json` file for example.
- :param resume: String, path to the checkpoint being loaded.
- :param modification: Dict keychain:value, specifying position values to be replaced from config dict.
- :param run_id: Unique Identifier for training processes. Used to save checkpoints and training log. Timestamp is being used as default
- """
- # load config file and apply modification
- self._config = _update_config(config, modification)
- self.resume = resume
-
- # set save_dir where trained model and log will be saved.
- save_dir = Path(self.config['trainer']['save_dir'])
-
- exper_name = self.config['name']
- if run_id is None: # use timestamp as default run-id
- run_id = datetime.now().strftime(r'%m%d_%H%M%S')
- self._save_dir = save_dir / 'models' / exper_name / run_id
- self._log_dir = save_dir / 'log' / exper_name / run_id
-
- # make directory for saving checkpoints and log.
- exist_ok = run_id == ''
- self.save_dir.mkdir(parents=True, exist_ok=exist_ok)
- self.log_dir.mkdir(parents=True, exist_ok=exist_ok)
-
- # save updated config file to the checkpoint dir
- write_json(self.config, self.save_dir / 'config.json')
-
- # configure logging module
- setup_logging(self.log_dir)
- self.log_levels = {
- 0: logging.WARNING,
- 1: logging.INFO,
- 2: logging.DEBUG
- }
-
- @classmethod
- def from_args(cls, args, options=''):
- """
- Initialize this class from some cli arguments. Used in train, test.
- """
- for opt in options:
- args.add_argument(*opt.flags, default=None, type=opt.type)
- if not isinstance(args, tuple):
- args = args.parse_args()
-
- if args.device is not None:
- os.environ["CUDA_VISIBLE_DEVICES"] = args.device
- if args.resume is not None:
- resume = Path(args.resume)
- cfg_fname = resume.parent / 'config.json'
- else:
- msg_no_cfg = "Configuration file need to be specified. Add '-c config.json', for example."
- assert args.config is not None, msg_no_cfg
- resume = None
- cfg_fname = Path(args.config)
-
- config = read_json(cfg_fname)
- if args.config and resume:
- # update new config for fine-tuning
- config.update(read_json(args.config))
-
- # parse custom cli options into dictionary
- modification = {opt.target : getattr(args, _get_opt_name(opt.flags)) for opt in options}
- return cls(config, resume, modification)
-
- def init_obj(self, name, module, *args, **kwargs):
- """
- Finds a function handle with the name given as 'type' in config, and returns the
- instance initialized with corresponding arguments given.
-
- `object = config.init_obj('name', module, a, b=1)`
- is equivalent to
- `object = module.name(a, b=1)`
- """
- module_name = self[name]['type']
- module_args = dict(self[name]['args'])
- assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed'
- module_args.update(kwargs)
- return getattr(module, module_name)(*args, **module_args)
-
- def init_ftn(self, name, module, *args, **kwargs):
- """
- Finds a function handle with the name given as 'type' in config, and returns the
- function with given arguments fixed with functools.partial.
-
- `function = config.init_ftn('name', module, a, b=1)`
- is equivalent to
- `function = lambda *args, **kwargs: module.name(a, *args, b=1, **kwargs)`.
- """
- module_name = self[name]['type']
- module_args = dict(self[name]['args'])
- assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed'
- module_args.update(kwargs)
- return partial(getattr(module, module_name), *args, **module_args)
-
- def __getitem__(self, name):
- """Access items like ordinary dict."""
- return self.config[name]
-
- def get_logger(self, name, verbosity=2):
- msg_verbosity = 'verbosity option {} is invalid. Valid options are {}.'.format(verbosity, self.log_levels.keys())
- assert verbosity in self.log_levels, msg_verbosity
- logger = logging.getLogger(name)
- logger.setLevel(self.log_levels[verbosity])
- return logger
-
- # setting read-only attributes
- @property
- def config(self):
- return self._config
-
- @property
- def save_dir(self):
- return self._save_dir
-
- @property
- def log_dir(self):
- return self._log_dir
-
- # helper functions to update config dict with custom cli options
- def _update_config(config, modification):
- if modification is None:
- return config
-
- for k, v in modification.items():
- if v is not None:
- _set_by_path(config, k, v)
- return config
-
- def _get_opt_name(flags):
- for flg in flags:
- if flg.startswith('--'):
- return flg.replace('--', '')
- return flags[0].replace('--', '')
-
- def _set_by_path(tree, keys, value):
- """Set a value in a nested object in tree by sequence of keys."""
- keys = keys.split(';')
- _get_by_path(tree, keys[:-1])[keys[-1]] = value
-
- def _get_by_path(tree, keys):
- """Access a nested object in tree by sequence of keys."""
- return reduce(getitem, keys, tree)
model/requirements.txt DELETED
@@ -1,6 +0,0 @@
- torch>=1.1
- torchvision
- numpy
- tqdm
- tensorboardx
- pandas
model/test.py DELETED
@@ -1,81 +0,0 @@
- import argparse
- import torch
- from tqdm import tqdm
- import data_loader.data_loaders as module_data
- import model.loss as module_loss
- import model.metric as module_metric
- import model.model as module_arch
- from parse_config import ConfigParser
-
-
- def main(config):
- logger = config.get_logger('test')
-
- # setup data_loader instances
- data_loader = getattr(module_data, config['data_loader']['type'])(
- config['data_loader']['args']['data_dir'],
- batch_size=512,
- shuffle=False,
- validation_split=0.0,
- training=False,
- num_workers=2
- )
-
- # build model architecture
- model = config.init_obj('arch', module_arch)
- logger.info(model)
-
- # get function handles of loss and metrics
- loss_fn = getattr(module_loss, config['loss'])
- metric_fns = [getattr(module_metric, met) for met in config['metrics']]
-
- logger.info('Loading checkpoint: {} ...'.format(config.resume))
- checkpoint = torch.load(config.resume)
- state_dict = checkpoint['state_dict']
- if config['n_gpu'] > 1:
- model = torch.nn.DataParallel(model)
- model.load_state_dict(state_dict)
-
- # prepare model for testing
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
- model = model.to(device)
- model.eval()
-
- total_loss = 0.0
- total_metrics = torch.zeros(len(metric_fns))
-
- with torch.no_grad():
- for i, (data, target) in enumerate(tqdm(data_loader)):
- data, target = data.to(device), target.to(device)
- output = model(data)
-
- #
- # save sample images, or do something with output here
- #
-
- # computing loss, metrics on test set
- loss = loss_fn(output, target)
- batch_size = data.shape[0]
- total_loss += loss.item() * batch_size
- for i, metric in enumerate(metric_fns):
- total_metrics[i] += metric(output, target) * batch_size
-
- n_samples = len(data_loader.sampler)
- log = {'loss': total_loss / n_samples}
- log.update({
- met.__name__: total_metrics[i].item() / n_samples for i, met in enumerate(metric_fns)
- })
- logger.info(log)
-
-
- if __name__ == '__main__':
- args = argparse.ArgumentParser(description='PyTorch Template')
- args.add_argument('-c', '--config', default=None, type=str,
- help='config file path (default: None)')
- args.add_argument('-r', '--resume', default=None, type=str,
- help='path to latest checkpoint (default: None)')
- args.add_argument('-d', '--device', default=None, type=str,
- help='indices of GPUs to enable (default: all)')
-
- config = ConfigParser.from_args(args)
- main(config)
model/train.py DELETED
@@ -1,73 +0,0 @@
- import argparse
- import collections
- import torch
- import numpy as np
- import data_loader.data_loaders as module_data
- import model.loss as module_loss
- import model.metric as module_metric
- import model.model as module_arch
- from parse_config import ConfigParser
- from trainer import Trainer
- from utils import prepare_device
-
-
- # fix random seeds for reproducibility
- SEED = 123
- torch.manual_seed(SEED)
- torch.backends.cudnn.deterministic = True
- torch.backends.cudnn.benchmark = False
- np.random.seed(SEED)
-
- def main(config):
- logger = config.get_logger('train')
-
- # setup data_loader instances
- data_loader = config.init_obj('data_loader', module_data)
- valid_data_loader = data_loader.split_validation()
-
- # build model architecture, then print to console
- model = config.init_obj('arch', module_arch)
- logger.info(model)
-
- # prepare for (multi-device) GPU training
- device, device_ids = prepare_device(config['n_gpu'])
- model = model.to(device)
- if len(device_ids) > 1:
- model = torch.nn.DataParallel(model, device_ids=device_ids)
-
- # get function handles of loss and metrics
- criterion = getattr(module_loss, config['loss'])
- metrics = [getattr(module_metric, met) for met in config['metrics']]
-
- # build optimizer, learning rate scheduler. delete every lines containing lr_scheduler for disabling scheduler
- trainable_params = filter(lambda p: p.requires_grad, model.parameters())
- optimizer = config.init_obj('optimizer', torch.optim, trainable_params)
- lr_scheduler = config.init_obj('lr_scheduler', torch.optim.lr_scheduler, optimizer)
-
- trainer = Trainer(model, criterion, metrics, optimizer,
- config=config,
- device=device,
- data_loader=data_loader,
- valid_data_loader=valid_data_loader,
- lr_scheduler=lr_scheduler)
-
- trainer.train()
-
-
- if __name__ == '__main__':
- args = argparse.ArgumentParser(description='PyTorch Template')
- args.add_argument('-c', '--config', default=None, type=str,
- help='config file path (default: None)')
- args.add_argument('-r', '--resume', default=None, type=str,
- help='path to latest checkpoint (default: None)')
- args.add_argument('-d', '--device', default=None, type=str,
- help='indices of GPUs to enable (default: all)')
-
- # custom cli options to modify configuration from default values given in json file.
- CustomArgs = collections.namedtuple('CustomArgs', 'flags type target')
- options = [
- CustomArgs(['--lr', '--learning_rate'], type=float, target='optimizer;args;lr'),
- CustomArgs(['--bs', '--batch_size'], type=int, target='data_loader;args;batch_size')
- ]
- config = ConfigParser.from_args(args, options)
- main(config)
model/trainer/__init__.py DELETED
@@ -1 +0,0 @@
- from .trainer import *
model/trainer/trainer.py DELETED
@@ -1,110 +0,0 @@
- import numpy as np
- import torch
- from torchvision.utils import make_grid
- from base import BaseTrainer
- from utils import inf_loop, MetricTracker
-
-
- class Trainer(BaseTrainer):
- """
- Trainer class
- """
- def __init__(self, model, criterion, metric_ftns, optimizer, config, device,
- data_loader, valid_data_loader=None, lr_scheduler=None, len_epoch=None):
- super().__init__(model, criterion, metric_ftns, optimizer, config)
- self.config = config
- self.device = device
- self.data_loader = data_loader
- if len_epoch is None:
- # epoch-based training
- self.len_epoch = len(self.data_loader)
- else:
- # iteration-based training
- self.data_loader = inf_loop(data_loader)
- self.len_epoch = len_epoch
- self.valid_data_loader = valid_data_loader
- self.do_validation = self.valid_data_loader is not None
- self.lr_scheduler = lr_scheduler
- self.log_step = int(np.sqrt(data_loader.batch_size))
-
- self.train_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)
- self.valid_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)
-
- def _train_epoch(self, epoch):
- """
- Training logic for an epoch
-
- :param epoch: Integer, current training epoch.
- :return: A log that contains average loss and metric in this epoch.
- """
- self.model.train()
- self.train_metrics.reset()
- for batch_idx, (data, target) in enumerate(self.data_loader):
- data, target = data.to(self.device), target.to(self.device)
-
- self.optimizer.zero_grad()
- output = self.model(data)
- loss = self.criterion(output, target)
- loss.backward()
- self.optimizer.step()
-
- self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
- self.train_metrics.update('loss', loss.item())
- for met in self.metric_ftns:
- self.train_metrics.update(met.__name__, met(output, target))
-
- if batch_idx % self.log_step == 0:
- self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
- epoch,
- self._progress(batch_idx),
- loss.item()))
- self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
-
- if batch_idx == self.len_epoch:
- break
- log = self.train_metrics.result()
-
- if self.do_validation:
- val_log = self._valid_epoch(epoch)
- log.update(**{'val_'+k : v for k, v in val_log.items()})
-
- if self.lr_scheduler is not None:
- self.lr_scheduler.step()
- return log
-
- def _valid_epoch(self, epoch):
- """
- Validate after training an epoch
-
- :param epoch: Integer, current training epoch.
- :return: A log that contains information about validation
- """
- self.model.eval()
- self.valid_metrics.reset()
- with torch.no_grad():
- for batch_idx, (data, target) in enumerate(self.valid_data_loader):
- data, target = data.to(self.device), target.to(self.device)
-
- output = self.model(data)
- loss = self.criterion(output, target)
-
- self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid')
- self.valid_metrics.update('loss', loss.item())
- for met in self.metric_ftns:
- self.valid_metrics.update(met.__name__, met(output, target))
- self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
-
- # add histogram of model parameters to the tensorboard
- for name, p in self.model.named_parameters():
- self.writer.add_histogram(name, p, bins='auto')
- return self.valid_metrics.result()
-
- def _progress(self, batch_idx):
- base = '[{}/{} ({:.0f}%)]'
- if hasattr(self.data_loader, 'n_samples'):
- current = batch_idx * self.data_loader.batch_size
- total = self.data_loader.n_samples
- else:
- current = batch_idx
- total = self.len_epoch
- return base.format(current, total, 100.0 * current / total)
model/utils/__init__.py DELETED
@@ -1 +0,0 @@
- from .util import *
model/utils/util.py DELETED
@@ -1,67 +0,0 @@
- import json
- import torch
- import pandas as pd
- from pathlib import Path
- from itertools import repeat
- from collections import OrderedDict
-
-
- def ensure_dir(dirname):
- dirname = Path(dirname)
- if not dirname.is_dir():
- dirname.mkdir(parents=True, exist_ok=False)
-
- def read_json(fname):
- fname = Path(fname)
- with fname.open('rt') as handle:
- return json.load(handle, object_hook=OrderedDict)
-
- def write_json(content, fname):
- fname = Path(fname)
- with fname.open('wt') as handle:
- json.dump(content, handle, indent=4, sort_keys=False)
-
- def inf_loop(data_loader):
- ''' wrapper function for endless data loader. '''
- for loader in repeat(data_loader):
- yield from loader
-
- def prepare_device(n_gpu_use):
- """
- setup GPU device if available. get gpu device indices which are used for DataParallel
- """
- n_gpu = torch.cuda.device_count()
- if n_gpu_use > 0 and n_gpu == 0:
- print("Warning: There\'s no GPU available on this machine,"
- "training will be performed on CPU.")
- n_gpu_use = 0
- if n_gpu_use > n_gpu:
- print(f"Warning: The number of GPU\'s configured to use is {n_gpu_use}, but only {n_gpu} are "
- "available on this machine.")
- n_gpu_use = n_gpu
- device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
- list_ids = list(range(n_gpu_use))
- return device, list_ids
-
- class MetricTracker:
- def __init__(self, *keys, writer=None):
- self.writer = writer
- self._data = pd.DataFrame(index=keys, columns=['total', 'counts', 'average'])
- self.reset()
-
- def reset(self):
- for col in self._data.columns:
- self._data[col].values[:] = 0
-
- def update(self, key, value, n=1):
- if self.writer is not None:
- self.writer.add_scalar(key, value)
- self._data.total[key] += value * n
- self._data.counts[key] += n
- self._data.average[key] = self._data.total[key] / self._data.counts[key]
-
- def avg(self, key):
- return self._data.average[key]
-
- def result(self):
- return dict(self._data.average)
models/.env.example ADDED
@@ -0,0 +1,7 @@
+ # this is example of the file that can be used for storing private and user specific environment variables, like keys or system paths
+ # create a file named .env (by default .env will be excluded from version control)
+ # the variables declared in .env are loaded in train.py automatically
+ # hydra allows you to reference variables in .yaml configs with special syntax: ${oc.env:MY_VAR}
+
+ MY_VAR="/home/user/my/system/path"
+ MY_KEY="asdgjhawi8y23ihsghsueity23ihwd"
models/.gitignore ADDED
@@ -0,0 +1,148 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pyenv
85
+ .python-version
86
+
87
+ # pipenv
88
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
90
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
91
+ # install all needed dependencies.
92
+ #Pipfile.lock
93
+
94
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95
+ __pypackages__/
96
+
97
+ # Celery stuff
98
+ celerybeat-schedule
99
+ celerybeat.pid
100
+
101
+ # SageMath parsed files
102
+ *.sage.py
103
+
104
+ # Environments
105
+ .venv
106
+ env/
107
+ venv/
108
+ ENV/
109
+ env.bak/
110
+ venv.bak/
111
+
112
+ # Spyder project settings
113
+ .spyderproject
114
+ .spyproject
115
+
116
+ # Rope project settings
117
+ .ropeproject
118
+
119
+ # mkdocs documentation
120
+ /site
121
+
122
+ # mypy
123
+ .mypy_cache/
124
+ .dmypy.json
125
+ dmypy.json
126
+
127
+ # Pyre type checker
128
+ .pyre/
129
+
130
+ ### VisualStudioCode
131
+ .vscode/*
132
+ !.vscode/settings.json
133
+ !.vscode/tasks.json
134
+ !.vscode/launch.json
135
+ !.vscode/extensions.json
136
+ *.code-workspace
137
+ **/.vscode
138
+
139
+ # JetBrains
140
+ .idea/
141
+
142
+ # Lightning-Hydra-Template
143
+ configs/local/default.yaml
144
+ data/
145
+ logs/
146
+ wandb/
147
+ .env
148
+ .autoenv
models/.pre-commit-config.yaml ADDED
@@ -0,0 +1,44 @@
1
+ repos:
2
+ - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: v4.1.0
4
+ hooks:
5
+ # list of supported hooks: https://pre-commit.com/hooks.html
6
+ - id: trailing-whitespace
7
+ - id: end-of-file-fixer
8
+ - id: check-yaml
9
+ - id: check-added-large-files
10
+ - id: debug-statements
11
+ - id: detect-private-key
12
+
13
+ # python code formatting
14
+ - repo: https://github.com/psf/black
15
+ rev: 22.1.0
16
+ hooks:
17
+ - id: black
18
+ args: [--line-length, "99"]
19
+
20
+ # python import sorting
21
+ - repo: https://github.com/PyCQA/isort
22
+ rev: 5.10.1
23
+ hooks:
24
+ - id: isort
25
+ args: ["--profile", "black", "--filter-files"]
26
+
27
+ # yaml formatting
28
+ - repo: https://github.com/pre-commit/mirrors-prettier
29
+ rev: v2.5.1
30
+ hooks:
31
+ - id: prettier
32
+ types: [yaml]
33
+
34
+ # python code analysis
35
+ - repo: https://github.com/PyCQA/flake8
36
+ rev: 4.0.1
37
+ hooks:
38
+ - id: flake8
39
+
40
+ # jupyter notebook cell output clearing
41
+ - repo: https://github.com/kynan/nbstripout
42
+ rev: 0.5.0
43
+ hooks:
44
+ - id: nbstripout
models/README.md ADDED
@@ -0,0 +1,1445 @@
1
+ <div align="center">
2
+
3
+ # Lightning-Hydra-Template
4
+
5
+ <a href="https://www.python.org/"><img alt="Python" src="https://img.shields.io/badge/-Python 3.7+-blue?style=for-the-badge&logo=python&logoColor=white"></a>
6
+ <a href="https://pytorch.org/get-started/locally/"><img alt="PyTorch" src="https://img.shields.io/badge/-PyTorch 1.8+-ee4c2c?style=for-the-badge&logo=pytorch&logoColor=white"></a>
7
+ <a href="https://pytorchlightning.ai/"><img alt="Lightning" src="https://img.shields.io/badge/-Lightning 1.5+-792ee5?style=for-the-badge&logo=pytorchlightning&logoColor=white"></a>
8
+ <a href="https://hydra.cc/"><img alt="Config: hydra" src="https://img.shields.io/badge/config-hydra 1.1-89b8cd?style=for-the-badge&labelColor=gray"></a>
9
+ <a href="https://black.readthedocs.io/en/stable/"><img alt="Code style: black" src="https://img.shields.io/badge/code%20style-black-black.svg?style=for-the-badge&labelColor=gray"></a>
10
+
11
+ A clean and scalable template to kickstart your deep learning project 🚀⚡🔥<br>
12
+ Click on [<kbd>Use this template</kbd>](https://github.com/ashleve/lightning-hydra-template/generate) to initialize a new repository.
13
+
14
+ _Suggestions are always welcome!_
15
+
16
+ </div>
17
+
18
+ <br><br>
19
+
20
+ ## 📌&nbsp;&nbsp;Introduction
21
+
22
+ This template tries to be as general as possible. It integrates many different MLOps tools.
23
+
24
+ > Effective usage of this template requires learning a couple of technologies: [PyTorch](https://pytorch.org), [PyTorch Lightning](https://www.pytorchlightning.ai) and [Hydra](https://hydra.cc). Knowledge of some experiment logging framework like [Weights&Biases](https://wandb.com), [Neptune](https://neptune.ai) or [MLFlow](https://mlflow.org) is also recommended.
25
+
26
+ **Why you should use it:** it allows you to rapidly iterate over new models/datasets and scale your projects from small single experiments to hyperparameter searches on computing clusters, without writing any boilerplate code. To my knowledge, it's one of the most convenient all-in-one technology stacks for Deep Learning research. It's a good starting point for reproducing papers, Kaggle competitions or small-team research projects. It's also a collection of best practices for efficient workflow and reproducibility.
27
+
28
+ **Why you shouldn't use it:** this template is not meant to be a production environment; it should be used more as a fast experimentation tool. Apart from that, Lightning and Hydra are still evolving and integrate many libraries, which means things sometimes break - for the list of currently known bugs, visit [this page](https://github.com/ashleve/lightning-hydra-template/labels/bug). Also, even though Lightning is very flexible, it's not well suited for every possible deep learning task. See [#Limitations](#limitations) for more.
29
+
30
+ ### Why PyTorch Lightning?
31
+
32
+ [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning) is a lightweight PyTorch wrapper for high-performance AI research.
33
+ It makes your code neatly organized and provides lots of useful features, like the ability to run your model on CPU, GPU, multi-GPU clusters and TPUs.
34
+
35
+ ### Why Hydra?
36
+
37
+ [Hydra](https://github.com/facebookresearch/hydra) is an open-source Python framework that simplifies the development of research and other complex applications. The key feature is the ability to dynamically create a hierarchical configuration by composition and override it through config files and the command line. It allows you to conveniently manage experiments and provides many useful plugins, like [Optuna Sweeper](https://hydra.cc/docs/next/plugins/optuna_sweeper) for hyperparameter search, or [Ray Launcher](https://hydra.cc/docs/next/plugins/ray_launcher) for running jobs on a cluster.
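+
+ As a rough illustration of that composition mechanism, a minimal Hydra entry point looks like the sketch below (illustrative only - the file name is a placeholder and the template's real entry point is `train.py`):
+
+ ```python
+ # minimal_app.py - illustrative sketch of a Hydra entry point, not part of the template
+ import hydra
+ from omegaconf import DictConfig, OmegaConf
+
+
+ @hydra.main(config_path="configs", config_name="train")
+ def main(config: DictConfig) -> None:
+     # the config is composed from the defaults list in configs/train.yaml,
+     # plus any command line overrides, e.g. `python minimal_app.py trainer.max_epochs=5`
+     print(OmegaConf.to_yaml(config))
+
+
+ if __name__ == "__main__":
+     main()
+ ```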
38
+
39
+ <br>
40
+
41
+ ## Main Ideas Of This Template
42
+
43
+ - **Predefined Structure**: clean and scalable so that work can easily be extended and replicated | [#Project Structure](#project-structure)
44
+ - **Rapid Experimentation**: thanks to automating the pipeline with config files and Hydra command line superpowers | [#Your Superpowers](#your-superpowers)
45
+ - **Reproducibility**: obtaining similar results is supported in multiple ways | [#Reproducibility](#reproducibility)
46
+ - **Little Boilerplate**: so pipeline can be easily modified | [#How It Works](#how-it-works)
47
+ - **Main Configuration**: main config file specifies default training configuration | [#Main Project Configuration](#main-project-configuration)
48
+ - **Experiment Configurations**: can be composed out of smaller configs and override chosen hyperparameters | [#Experiment Configuration](#experiment-configuration)
49
+ - **Workflow**: comes down to 4 simple steps | [#Workflow](#workflow)
50
+ - **Experiment Tracking**: many logging frameworks can be easily integrated, like Tensorboard, MLFlow or W&B | [#Experiment Tracking](#experiment-tracking)
51
+ - **Logs**: all logs (checkpoints, data from loggers, hparams, etc.) are stored in a convenient folder structure imposed by Hydra | [#Logs](#logs)
52
+ - **Hyperparameter Search**: made easier with Hydra built-in plugins like [Optuna Sweeper](https://hydra.cc/docs/next/plugins/optuna_sweeper) | [#Hyperparameter Search](#hyperparameter-search)
53
+ - **Tests**: unit tests and shell/command based tests for speeding up the development | [#Tests](#tests)
54
+ - **Best Practices**: a couple of recommended tools, practices and standards for efficient workflow and reproducibility | [#Best Practices](#best-practices)
55
+
56
+ <br>
57
+
58
+ ## Project Structure
59
+
60
+ The directory structure of a new project looks like this:
61
+
62
+ ```
63
+ ├── configs <- Hydra configuration files
64
+ │ ├── callbacks <- Callbacks configs
65
+ │ ├── datamodule <- Datamodule configs
66
+ │ ├── debug <- Debugging configs
67
+ │ ├── experiment <- Experiment configs
68
+ │ ├── hparams_search <- Hyperparameter search configs
69
+ │ ├── local <- Local configs
70
+ │ ├── log_dir <- Logging directory configs
71
+ │ ├── logger <- Logger configs
72
+ │ ├── model <- Model configs
73
+ │ ├── trainer <- Trainer configs
74
+ │ │
75
+ │ ├── test.yaml <- Main config for testing
76
+ │ └── train.yaml <- Main config for training
77
+
78
+ ├── data <- Project data
79
+
80
+ ├── logs <- Logs generated by Hydra and PyTorch Lightning loggers
81
+
82
+ ├── notebooks <- Jupyter notebooks. Naming convention is a number (for ordering),
83
+ │ the creator's initials, and a short `-` delimited description,
84
+ │ e.g. `1.0-jqp-initial-data-exploration.ipynb`.
85
+
86
+ ├── scripts <- Shell scripts
87
+
88
+ ├── src <- Source code
89
+ │ ├── datamodules <- Lightning datamodules
90
+ │ ├── models <- Lightning models
91
+ │ ├── utils <- Utility scripts
92
+ │ ├── vendor <- Third party code that cannot be installed using PIP/Conda
93
+ │ │
94
+ │ ├── testing_pipeline.py
95
+ │ └── training_pipeline.py
96
+
97
+ ├── tests <- Tests of any kind
98
+ │ ├── helpers <- A couple of testing utilities
99
+ │ ├── shell <- Shell/command based tests
100
+ │ └── unit <- Unit tests
101
+
102
+ ├── test.py <- Run testing
103
+ ├── train.py <- Run training
104
+
105
+ ├── .env.example <- Template of the file for storing private environment variables
106
+ ├── .gitignore <- List of files/folders ignored by git
107
+ ├── .pre-commit-config.yaml <- Configuration of pre-commit hooks for code formatting
108
+ ├── requirements.txt <- File for installing python dependencies
109
+ ├── setup.cfg <- Configuration of linters and pytest
110
+ └── README.md
111
+ ```
112
+
113
+ <br>
114
+
115
+ ## 🚀&nbsp;&nbsp;Quickstart
116
+
117
+ ```bash
118
+ # clone project
119
+ git clone https://github.com/ashleve/lightning-hydra-template
120
+ cd lightning-hydra-template
121
+
122
+ # [OPTIONAL] create conda environment
123
+ conda create -n myenv python=3.8
124
+ conda activate myenv
125
+
126
+ # install pytorch according to instructions
127
+ # https://pytorch.org/get-started/
128
+
129
+ # install requirements
130
+ pip install -r requirements.txt
131
+ ```
132
+
133
+ The template contains an example with MNIST classification.<br>
134
+ When running `python train.py` you should see something like this:
135
+
136
+ <div align="center">
137
+
138
+ ![](https://github.com/ashleve/lightning-hydra-template/blob/resources/terminal.png)
139
+
140
+ </div>
141
+
142
+ ### ⚡&nbsp;&nbsp;Your Superpowers
143
+
144
+ <details>
145
+ <summary><b>Override any config parameter from command line</b></summary>
146
+
147
+ > Hydra allows you to easily overwrite any parameter defined in your config.
148
+
149
+ ```bash
150
+ python train.py trainer.max_epochs=20 model.lr=1e-4
151
+ ```
152
+
153
+ > You can also add new parameters with `+` sign.
154
+
155
+ ```bash
156
+ python train.py +model.new_param="uwu"
157
+ ```
158
+
159
+ </details>
160
+
161
+ <details>
162
+ <summary><b>Train on CPU, GPU, multi-GPU and TPU</b></summary>
163
+
164
+ > PyTorch Lightning makes it easy to train your models on different hardware.
165
+
166
+ ```bash
167
+ # train on CPU
168
+ python train.py trainer.gpus=0
169
+
170
+ # train on 1 GPU
171
+ python train.py trainer.gpus=1
172
+
173
+ # train on TPU
174
+ python train.py +trainer.tpu_cores=8
175
+
176
+ # train with DDP (Distributed Data Parallel) (4 GPUs)
177
+ python train.py trainer.gpus=4 +trainer.strategy=ddp
178
+
179
+ # train with DDP (Distributed Data Parallel) (8 GPUs, 2 nodes)
180
+ python train.py trainer.gpus=4 +trainer.num_nodes=2 +trainer.strategy=ddp
181
+ ```
182
+
183
+ </details>
184
+
185
+ <details>
186
+ <summary><b>Train with mixed precision</b></summary>
187
+
188
+ ```bash
189
+ # train with pytorch native automatic mixed precision (AMP)
190
+ python train.py trainer.gpus=1 +trainer.precision=16
191
+ ```
192
+
193
+ </details>
194
+
195
+ <!-- deepspeed support still in beta
196
+ <details>
197
+ <summary><b>Optimize large scale models on multiple GPUs with Deepspeed</b></summary>
198
+
199
+ ```bash
200
+ python train.py +trainer.
201
+ ```
202
+
203
+ </details>
204
+ -->
205
+
206
+ <details>
207
+ <summary><b>Train model with any logger available in PyTorch Lightning, like Weights&Biases or Tensorboard</b></summary>
208
+
209
+ > PyTorch Lightning provides convenient integrations with the most popular logging frameworks, like Tensorboard, Neptune or simple csv files. Read more [here](#experiment-tracking). Using wandb requires you to [set up an account](https://www.wandb.com/) first. After that, just complete the config as below.<br> **Click [here](https://wandb.ai/hobglob/template-dashboard/) to see an example wandb dashboard generated with this template.**
210
+
211
+ ```yaml
212
+ # set project and entity names in `configs/logger/wandb`
213
+ wandb:
214
+ project: "your_project_name"
215
+ entity: "your_wandb_team_name"
216
+ ```
217
+
218
+ ```bash
219
+ # train model with Weights&Biases (link to wandb dashboard should appear in the terminal)
220
+ python train.py logger=wandb
221
+ ```
222
+
223
+ </details>
224
+
225
+ <details>
226
+ <summary><b>Train model with chosen experiment config</b></summary>
227
+
228
+ > Experiment configurations are placed in [configs/experiment/](configs/experiment/).
229
+
230
+ ```bash
231
+ python train.py experiment=example
232
+ ```
233
+
234
+ </details>
235
+
236
+ <details>
237
+ <summary><b>Attach some callbacks to run</b></summary>
238
+
239
+ > Callbacks can be used for things such as model checkpointing, early stopping and [many more](https://pytorch-lightning.readthedocs.io/en/latest/extensions/callbacks.html#built-in-callbacks).<br>
240
+ > Callbacks configurations are placed in [configs/callbacks/](configs/callbacks/).
241
+
242
+ ```bash
243
+ python train.py callbacks=default
244
+ ```
245
+
246
+ </details>
247
+
248
+ <details>
249
+ <summary><b>Use different tricks available in Pytorch Lightning</b></summary>
250
+
251
+ > PyTorch Lightning provides about [40+ useful trainer flags](https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags).
252
+
253
+ ```bash
254
+ # gradient clipping may be enabled to avoid exploding gradients
255
+ python train.py +trainer.gradient_clip_val=0.5
256
+
257
+ # stochastic weight averaging can make your models generalize better
258
+ python train.py +trainer.stochastic_weight_avg=true
259
+
260
+ # run validation loop 4 times during a training epoch
261
+ python train.py +trainer.val_check_interval=0.25
262
+
263
+ # accumulate gradients
264
+ python train.py +trainer.accumulate_grad_batches=10
265
+
266
+ # terminate training after 12 hours
267
+ python train.py +trainer.max_time="00:12:00:00"
268
+ ```
269
+
270
+ </details>
271
+
272
+ <details>
273
+ <summary><b>Easily debug</b></summary>
274
+
275
+ > Visit [configs/debug/](configs/debug/) for different debugging configs.
276
+
277
+ ```bash
278
+ # runs 1 epoch in default debugging mode
279
+ # changes logging directory to `logs/debugs/...`
280
+ # sets level of all command line loggers to 'DEBUG'
281
+ # enables extra trainer flags like tracking gradient norm
282
+ # enforces debug-friendly configuration
283
+ python train.py debug=default
284
+
285
+ # runs test epoch without training
286
+ python train.py debug=test_only
287
+
288
+ # run 1 train, val and test loop, using only 1 batch
289
+ python train.py +trainer.fast_dev_run=true
290
+
291
+ # raise exception if there are any numerical anomalies in tensors, like NaN or +/-inf
292
+ python train.py +trainer.detect_anomaly=true
293
+
294
+ # print execution time profiling after training ends
295
+ python train.py +trainer.profiler="simple"
296
+
297
+ # try overfitting to 1 batch
298
+ python train.py +trainer.overfit_batches=1 trainer.max_epochs=20
299
+
300
+ # use only 20% of the data
301
+ python train.py +trainer.limit_train_batches=0.2 \
302
+ +trainer.limit_val_batches=0.2 +trainer.limit_test_batches=0.2
303
+
304
+ # log second gradient norm of the model
305
+ python train.py +trainer.track_grad_norm=2
306
+ ```
307
+
308
+ </details>
309
+
310
+ <details>
311
+ <summary><b>Resume training from checkpoint</b></summary>
312
+
313
+ > Checkpoint can be either path or URL.
314
+
315
+ ```bash
316
+ python train.py trainer.resume_from_checkpoint="/path/to/ckpt/name.ckpt"
317
+ ```
318
+
319
+ > ⚠️ Currently loading a ckpt in Lightning doesn't resume the logger experiment, but it will be supported in a future Lightning release.
320
+
321
+ </details>
322
+
323
+ <details>
324
+ <summary><b>Execute evaluation for a given checkpoint</b></summary>
325
+
326
+ > Checkpoint can be either path or URL.
327
+
328
+ ```bash
329
+ python test.py ckpt_path="/path/to/ckpt/name.ckpt"
330
+ ```
331
+
332
+ </details>
333
+
334
+ <details>
335
+ <summary><b>Create a sweep over hyperparameters</b></summary>
336
+
337
+ ```bash
338
+ # this will run 6 experiments one after the other,
339
+ # each with different combination of batch_size and learning rate
340
+ python train.py -m datamodule.batch_size=32,64,128 model.lr=0.001,0.0005
341
+ ```
342
+
343
+ > ⚠️ This sweep is not failure resistant (if one job crashes then the whole sweep crashes).
344
+
345
+ </details>
346
+
347
+ <details>
348
+ <summary><b>Create a sweep over hyperparameters with Optuna</b></summary>
349
+
350
+ > Using [Optuna Sweeper](https://hydra.cc/docs/next/plugins/optuna_sweeper) plugin doesn't require you to code any boilerplate into your pipeline, everything is defined in a [single config file](configs/hparams_search/mnist_optuna.yaml)!
351
+
352
+ ```bash
353
+ # this will run hyperparameter search defined in `configs/hparams_search/mnist_optuna.yaml`
354
+ # over chosen experiment config
355
+ python train.py -m hparams_search=mnist_optuna experiment=example_simple
356
+ ```
357
+
358
+ > ⚠️ Currently this sweep is not failure resistant (if one job crashes then the whole sweep crashes). Might be supported in a future Hydra release.
359
+
360
+ </details>
361
+
362
+ <details>
363
+ <summary><b>Execute all experiments from folder</b></summary>
364
+
365
+ > Hydra provides special syntax for controlling the behavior of multiruns. Learn more [here](https://hydra.cc/docs/next/tutorials/basic/running_your_app/multi-run). The command below executes all experiments from the folder [configs/experiment/](configs/experiment/).
366
+
367
+ ```bash
368
+ python train.py -m 'experiment=glob(*)'
369
+ ```
370
+
371
+ </details>
372
+
373
+ <details>
374
+ <summary><b>Execute sweep on a remote AWS cluster</b></summary>
375
+
376
+ > This should be achievable with simple config using [Ray AWS launcher for Hydra](https://hydra.cc/docs/next/plugins/ray_launcher). Example is not yet implemented in this template.
377
+
378
+ </details>
379
+
380
+ <!-- <details>
381
+ <summary><b>Execute sweep on a SLURM cluster</b></summary>
382
+
383
+ > This should be achievable with either [the right lightning trainer flags](https://pytorch-lightning.readthedocs.io/en/latest/clouds/cluster.html?highlight=SLURM#slurm-managed-cluster) or simple config using [Submitit launcher for Hydra](https://hydra.cc/docs/plugins/submitit_launcher). Example is not yet implemented in this template.
384
+
385
+ </details> -->
386
+
387
+ <details>
388
+ <summary><b>Use Hydra tab completion</b></summary>
389
+
390
+ > Hydra allows you to autocomplete config argument overrides in shell as you write them, by pressing `tab` key. Learn more [here](https://hydra.cc/docs/tutorials/basic/running_your_app/tab_completion).
391
+
392
+ </details>
393
+
394
+ <details>
395
+ <summary><b>Apply pre-commit hooks</b></summary>
396
+
397
+ > Apply pre-commit hooks to automatically format your code and configs, perform code analysis and remove output from jupyter notebooks. See [# Best Practices](#best-practices) for more.
398
+
399
+ ```bash
400
+ pre-commit run -a
401
+ ```
402
+
403
+ </details>
404
+
405
+ <br>
406
+
407
+ ## ❤️&nbsp;&nbsp;Contributions
408
+
409
+ Have a question? Found a bug? Missing a specific feature? Have an idea for improving documentation? Feel free to file a new issue, discussion or PR with respective title and description. If you already found a solution to your problem, don't hesitate to share it. Suggestions for new best practices are always welcome!
410
+
411
+ <br>
412
+
413
+ ## ℹ️&nbsp;&nbsp;Guide
414
+
415
+ ### How To Get Started
416
+
417
+ - First, you should probably get familiar with [PyTorch Lightning](https://www.pytorchlightning.ai)
418
+ - Next, go through [Hydra quick start guide](https://hydra.cc/docs/intro/) and [basic Hydra tutorial](https://hydra.cc/docs/tutorials/basic/your_first_app/simple_cli/)
419
+
420
+ <br>
421
+
422
+ ### How It Works
423
+
424
+ All PyTorch Lightning modules are dynamically instantiated from module paths specified in config. Example model config:
425
+
426
+ ```yaml
427
+ _target_: src.models.mnist_model.MNISTLitModule
428
+ input_size: 784
429
+ lin1_size: 256
430
+ lin2_size: 256
431
+ lin3_size: 256
432
+ output_size: 10
433
+ lr: 0.001
434
+ ```
435
+
436
+ Using this config we can instantiate the object with the following line:
437
+
438
+ ```python
439
+ model = hydra.utils.instantiate(config.model)
440
+ ```
441
+
442
+ This allows you to easily iterate over new models! Every time you create a new one, just specify its module path and parameters in the appropriate config file. <br>
443
+
444
+ Switch between models and datamodules with command line arguments:
445
+
446
+ ```bash
447
+ python train.py model=mnist
448
+ ```
449
+
450
+ The whole pipeline managing the instantiation logic is placed in [src/training_pipeline.py](src/training_pipeline.py).
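+
+ In broad strokes, such a pipeline boils down to a few `hydra.utils.instantiate` calls followed by `trainer.fit` - the snippet below is a simplified sketch, not the actual contents of that file:
+
+ ```python
+ # simplified sketch of a training pipeline (illustrative, details differ from the real file)
+ import hydra
+ from omegaconf import DictConfig
+ from pytorch_lightning import LightningDataModule, LightningModule, Trainer, seed_everything
+
+
+ def train(config: DictConfig) -> None:
+     if config.get("seed"):
+         seed_everything(config.seed, workers=True)
+
+     # instantiate datamodule, model and trainer from the _target_ paths in the config
+     datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule)
+     model: LightningModule = hydra.utils.instantiate(config.model)
+     trainer: Trainer = hydra.utils.instantiate(config.trainer)
+
+     trainer.fit(model=model, datamodule=datamodule)
+
+     # optionally evaluate on the test set using the best checkpoint
+     if config.get("test"):
+         trainer.test(model=model, datamodule=datamodule, ckpt_path="best")
+ ```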
451
+
452
+ <br>
453
+
454
+ ### Main Project Configuration
455
+
456
+ Location: [configs/train.yaml](configs/train.yaml) <br>
457
+ Main project config contains default training configuration.<br>
458
+ It determines how config is composed when simply executing command `python train.py`.<br>
459
+
460
+ <details>
461
+ <summary><b>Show main project config</b></summary>
462
+
463
+ ```yaml
464
+ # specify here default training configuration
465
+ defaults:
466
+ - _self_
467
+ - datamodule: mnist.yaml
468
+ - model: mnist.yaml
469
+ - callbacks: default.yaml
470
+ - logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
471
+ - trainer: default.yaml
472
+ - log_dir: default.yaml
473
+
474
+ # experiment configs allow for version control of specific configurations
475
+ # e.g. best hyperparameters for each combination of model and datamodule
476
+ - experiment: null
477
+
478
+ # debugging config (enable through command line, e.g. `python train.py debug=default`)
479
+ - debug: null
480
+
481
+ # config for hyperparameter optimization
482
+ - hparams_search: null
483
+
484
+ # optional local config for machine/user specific settings
485
+ # it's optional since it doesn't need to exist and is excluded from version control
486
+ - optional local: default.yaml
487
+
488
+ # enable color logging
489
+ - override hydra/hydra_logging: colorlog
490
+ - override hydra/job_logging: colorlog
491
+
492
+ # path to original working directory
493
+ # hydra hijacks working directory by changing it to the new log directory
494
+ # https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory
495
+ original_work_dir: ${hydra:runtime.cwd}
496
+
497
+ # path to folder with data
498
+ data_dir: ${original_work_dir}/data/
499
+
500
+ # pretty print config at the start of the run using Rich library
501
+ print_config: True
502
+
503
+ # disable python warnings if they annoy you
504
+ ignore_warnings: True
505
+
506
+ # set False to skip model training
507
+ train: True
508
+
509
+ # evaluate on test set, using best model weights achieved during training
510
+ # lightning chooses best weights based on the metric specified in checkpoint callback
511
+ test: True
512
+
513
+ # seed for random number generators in pytorch, numpy and python.random
514
+ seed: null
515
+
516
+ # default name for the experiment, determines logging folder path
517
+ # (you can overwrite this name in experiment configs)
518
+ name: "default"
519
+ ```
520
+
521
+ </details>
522
+
523
+ <br>
524
+
525
+ ### Experiment Configuration
526
+
527
+ Location: [configs/experiment](configs/experiment)<br>
528
+ Experiment configs allow you to overwrite parameters from main project configuration.<br>
529
+ For example, you can use them to version control best hyperparameters for each combination of model and dataset.
530
+
531
+ <details>
532
+ <summary><b>Show example experiment config</b></summary>
533
+
534
+ ```yaml
535
+ # to execute this experiment run:
536
+ # python train.py experiment=example
537
+
538
+ defaults:
539
+ - override /datamodule: mnist.yaml
540
+ - override /model: mnist.yaml
541
+ - override /callbacks: default.yaml
542
+ - override /logger: null
543
+ - override /trainer: default.yaml
544
+
545
+ # all parameters below will be merged with parameters from default configurations set above
546
+ # this allows you to overwrite only specified parameters
547
+
548
+ # name of the run determines folder name in logs
549
+ name: "simple_dense_net"
550
+
551
+ seed: 12345
552
+
553
+ trainer:
554
+ min_epochs: 10
555
+ max_epochs: 10
556
+ gradient_clip_val: 0.5
557
+
558
+ model:
559
+ lin1_size: 128
560
+ lin2_size: 256
561
+ lin3_size: 64
562
+ lr: 0.002
563
+
564
+ datamodule:
565
+ batch_size: 64
566
+
567
+ logger:
568
+ wandb:
569
+ tags: ["mnist", "${name}"]
570
+ ```
571
+
572
+ </details>
573
+
574
+ <br>
575
+
576
+ ### Local Configuration
577
+
578
+ Location: [configs/local](configs/local) <br>
579
+ Some configurations are user/machine/installation specific (e.g. the configuration of a local cluster, or hard drive paths on a specific machine). For such scenarios, a file `configs/local/default.yaml` can be created which is automatically loaded but not tracked by Git.
580
+
581
+ <details>
582
+ <summary><b>Show example local Slurm cluster config</b></summary>
583
+
584
+ ```yaml
585
+ # @package _global_
586
+
587
+ defaults:
588
+ - override /hydra/launcher@_here_: submitit_slurm
589
+
590
+ data_dir: /mnt/scratch/data/
591
+
592
+ hydra:
593
+ launcher:
594
+ timeout_min: 1440
595
+ gpus_per_task: 1
596
+ gres: gpu:1
597
+ job:
598
+ env_set:
599
+ MY_VAR: /home/user/my/system/path
600
+ MY_KEY: asdgjhawi8y23ihsghsueity23ihwd
601
+ ```
602
+
603
+ </details>
604
+
605
+ <br>
606
+
607
+ ### Workflow
608
+
609
+ 1. Write your PyTorch Lightning module (see [models/mnist_module.py](src/models/mnist_module.py) for example)
610
+ 2. Write your PyTorch Lightning datamodule (see [datamodules/mnist_datamodule.py](src/datamodules/mnist_datamodule.py) for example)
611
+ 3. Write your experiment config, containing paths to your model and datamodule
612
+ 4. Run training with chosen experiment config: `python train.py experiment=experiment_name`
613
+
614
+ <br>
615
+
616
+ ### Logs
617
+
618
+ **Hydra creates a new working directory for every executed run.** By default, logs have the following structure:
619
+
620
+ ```
621
+ ├── logs
622
+ │ ├── experiments # Folder for the logs generated by experiments
623
+ │ │ ├── runs # Folder for single runs
624
+ │ │ │ ├── experiment_name # Experiment name
625
+ │ │ │ │ ├── YYYY-MM-DD_HH-MM-SS # Datetime of the run
626
+ │ │ │ │ │ ├── .hydra # Hydra logs
627
+ │ │ │ │ │ ├── csv # Csv logs
628
+ │ │ │ │ │ ├── wandb # Weights&Biases logs
629
+ │ │ │ │ │ ├── checkpoints # Training checkpoints
630
+ │ │ │ │ │ └── ... # Any other thing saved during training
631
+ │ │ │ │ └── ...
632
+ │ │ │ └── ...
633
+ │ │ │
634
+ │ │ └── multiruns # Folder for multiruns
635
+ │ │ ├── experiment_name # Experiment name
636
+ │ │ │ ├── YYYY-MM-DD_HH-MM-SS # Datetime of the multirun
637
+ │ │ │ │ ├──1 # Multirun job number
638
+ │ │ │ │ ├──2
639
+ │ │ │ │ └── ...
640
+ │ │ │ └── ...
641
+ │ │ └── ...
642
+ │ │
643
+ │ ├── evaluations # Folder for the logs generated during testing
644
+ │ │ └── ...
645
+ │ │
646
+ │ └── debugs # Folder for the logs generated during debugging
647
+ │ └── ...
648
+ ```
649
+
650
+ You can change this structure by modifying paths in [hydra configuration](configs/log_dir).
651
+
652
+ <br>
653
+
654
+ ### Experiment Tracking
655
+
656
+ PyTorch Lightning supports many popular logging frameworks:<br>
657
+ **[Weights&Biases](https://www.wandb.com/) · [Neptune](https://neptune.ai/) · [Comet](https://www.comet.ml/) · [MLFlow](https://mlflow.org) · [Tensorboard](https://www.tensorflow.org/tensorboard/)**
658
+
659
+ These tools help you keep track of hyperparameters and output metrics and allow you to compare and visualize results. To use one of them simply complete its configuration in [configs/logger](configs/logger) and run:
660
+
661
+ ```bash
662
+ python train.py logger=logger_name
663
+ ```
664
+
665
+ You can use many of them at once (see [configs/logger/many_loggers.yaml](configs/logger/many_loggers.yaml) for example).
666
+
667
+ You can also write your own logger.
668
+
669
+ Lightning provides a convenient method for logging custom metrics from inside the LightningModule. Read the docs [here](https://pytorch-lightning.readthedocs.io/en/latest/extensions/logging.html#automatic-logging) or take a look at the [MNIST example](src/models/mnist_module.py).
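+
+ In short, calling `self.log()` from any step method is enough for the configured logger to pick the metric up - a minimal sketch (it assumes a `criterion` attribute was defined in `__init__`):
+
+ ```python
+ # minimal sketch of logging a custom metric from inside a LightningModule step
+ def training_step(self, batch, batch_idx):
+     x, y = batch
+     loss = self.criterion(self(x), y)  # assumes self.criterion was set in __init__
+     # the logged value is forwarded to whichever logger is configured (csv, wandb, ...)
+     self.log("train/loss", loss, on_step=False, on_epoch=True, prog_bar=True)
+     return loss
+ ```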
670
+
671
+ <br>
672
+
673
+ ### Hyperparameter Search
674
+
675
+ Defining hyperparameter optimization is as easy as adding a new config file to [configs/hparams_search](configs/hparams_search).
676
+
677
+ <details>
678
+ <summary><b>Show example</b></summary>
679
+
680
+ ```yaml
681
+ defaults:
682
+ - override /hydra/sweeper: optuna
683
+
684
+ # choose metric which will be optimized by Optuna
685
+ optimized_metric: "val/acc_best"
686
+
687
+ hydra:
688
+ # here we define Optuna hyperparameter search
689
+ # it optimizes for value returned from function with @hydra.main decorator
690
+ # learn more here: https://hydra.cc/docs/next/plugins/optuna_sweeper
691
+ sweeper:
692
+ _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper
693
+ storage: null
694
+ study_name: null
695
+ n_jobs: 1
696
+
697
+ # 'minimize' or 'maximize' the objective
698
+ direction: maximize
699
+
700
+ # number of experiments that will be executed
701
+ n_trials: 20
702
+
703
+ # choose Optuna hyperparameter sampler
704
+ # learn more here: https://optuna.readthedocs.io/en/stable/reference/samplers.html
705
+ sampler:
706
+ _target_: optuna.samplers.TPESampler
707
+ seed: 12345
708
+ consider_prior: true
709
+ prior_weight: 1.0
710
+ consider_magic_clip: true
711
+ consider_endpoints: false
712
+ n_startup_trials: 10
713
+ n_ei_candidates: 24
714
+ multivariate: false
715
+ warn_independent_sampling: true
716
+
717
+ # define range of hyperparameters
718
+ search_space:
719
+ datamodule.batch_size:
720
+ type: categorical
721
+ choices: [32, 64, 128]
722
+ model.lr:
723
+ type: float
724
+ low: 0.0001
725
+ high: 0.2
726
+ model.lin1_size:
727
+ type: categorical
728
+ choices: [32, 64, 128, 256, 512]
729
+ model.lin2_size:
730
+ type: categorical
731
+ choices: [32, 64, 128, 256, 512]
732
+ model.lin3_size:
733
+ type: categorical
734
+ choices: [32, 64, 128, 256, 512]
735
+ ```
736
+
737
+ </details>
738
+
739
+ Next, you can execute it with: `python train.py -m hparams_search=mnist_optuna`
740
+
741
+ Using this approach doesn't require you to add any boilerplate into your pipeline, everything is defined in a single config file.
742
+
743
+ You can use different optimization frameworks integrated with Hydra, like Optuna, Ax or Nevergrad.
744
+
745
+ The `optimization_results.yaml` will be available under `logs/multirun` folder.
746
+
747
+ This approach doesn't support advanced techniques like pruning - for a more sophisticated search, you probably shouldn't use the Hydra multirun feature and should instead write your own optimization pipeline.
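+
+ If you do write your own pipeline, a standalone Optuna study with pruning could look roughly like the sketch below - `run_training` is a hypothetical helper that trains a model with the given hyperparameters and returns the optimized metric:
+
+ ```python
+ # rough sketch of a standalone Optuna study with pruning (illustrative only)
+ import optuna
+
+
+ def objective(trial: optuna.Trial) -> float:
+     hparams = {
+         "model.lr": trial.suggest_float("model.lr", 1e-4, 2e-1, log=True),
+         "datamodule.batch_size": trial.suggest_categorical("datamodule.batch_size", [32, 64, 128]),
+     }
+     # `run_training` is hypothetical: it should train with `hparams` and report intermediate
+     # values via trial.report() so the pruner can stop unpromising trials early
+     return run_training(hparams, trial=trial)
+
+
+ study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
+ study.optimize(objective, n_trials=20)
+ print(study.best_params)
+ ```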
748
+
749
+ <br>
750
+
751
+ ### Inference
752
+
753
+ The following code is an example of loading a model from a checkpoint and running predictions.<br>
754
+
755
+ <details>
756
+ <summary><b>Show example</b></summary>
757
+
758
+ ```python
759
+ from PIL import Image
760
+ from torchvision import transforms
761
+
762
+ from src.models.mnist_module import MNISTLitModule
763
+
764
+
765
+ def predict():
766
+ """Example of inference with trained model.
767
+ It loads trained image classification model from checkpoint.
768
+ Then it loads example image and predicts its label.
769
+ """
770
+
771
+ # ckpt can be also a URL!
772
+ CKPT_PATH = "last.ckpt"
773
+
774
+ # load model from checkpoint
775
+ # model __init__ parameters will be loaded from ckpt automatically
776
+ # you can also pass some parameter explicitly to override it
777
+ trained_model = MNISTLitModule.load_from_checkpoint(checkpoint_path=CKPT_PATH)
778
+
779
+ # print model hyperparameters
780
+ print(trained_model.hparams)
781
+
782
+ # switch to evaluation mode
783
+ trained_model.eval()
784
+ trained_model.freeze()
785
+
786
+ # load data
787
+ img = Image.open("data/example_img.png").convert("L") # convert to black and white
788
+ # img = Image.open("data/example_img.png").convert("RGB") # convert to RGB
789
+
790
+ # preprocess
791
+ mnist_transforms = transforms.Compose(
792
+ [
793
+ transforms.ToTensor(),
794
+ transforms.Resize((28, 28)),
795
+ transforms.Normalize((0.1307,), (0.3081,)),
796
+ ]
797
+ )
798
+ img = mnist_transforms(img)
799
+ img = img.reshape((1, *img.size())) # reshape to form batch of size 1
800
+
801
+ # inference
802
+ output = trained_model(img)
803
+ print(output)
804
+
805
+
806
+ if __name__ == "__main__":
807
+ predict()
808
+
809
+ ```
810
+
811
+ </details>
812
+
813
+ <br>
814
+
815
+ ### Tests
816
+
817
+ The template comes with example tests implemented with the pytest library. To execute them simply run:
818
+
819
+ ```bash
820
+ # run all tests
821
+ pytest
822
+
823
+ # run tests from specific file
824
+ pytest tests/shell/test_basic_commands.py
825
+
826
+ # run all tests except the ones marked as slow
827
+ pytest -k "not slow"
828
+ ```
829
+
830
+ To speed up development, you can once in a while execute tests that run a couple of quick experiments, like training 1 epoch on 25% of the data, executing a single train/val/test step, etc. Those kinds of tests don't check for any specific output - they simply verify that executing some bash commands doesn't end up throwing exceptions. You can find them implemented in the [tests/shell](tests/shell) folder.
831
+
832
+ You can easily modify the commands in the scripts for your use case. If 1 epoch is too much for your model, then make it run for a couple of batches instead (by using the right trainer flags).
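+
+ As a rough illustration (not the actual content of `tests/shell`), such a smoke test can be as simple as asserting that a short training command exits cleanly:
+
+ ```python
+ # rough sketch of a command-based smoke test (illustrative only)
+ import subprocess
+
+
+ def test_fast_dev_run():
+     # only verifies that the command exits without throwing - no outputs are checked
+     command = ["python", "train.py", "+trainer.fast_dev_run=true"]
+     result = subprocess.run(command, capture_output=True, text=True)
+     assert result.returncode == 0, result.stderr
+ ```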
833
+
834
+ <br>
835
+
836
+ ### Callbacks
837
+
838
+ The branch [`wandb-callbacks`](https://github.com/ashleve/lightning-hydra-template/tree/wandb-callbacks) contains example callbacks enabling better Weights&Biases integration, which you can use as a reference for writing your own callbacks (see [wandb_callbacks.py](https://github.com/ashleve/lightning-hydra-template/tree/wandb-callbacks/src/callbacks/wandb_callbacks.py)).
839
+
840
+ Callbacks which support reproducibility:
841
+
842
+ - **WatchModel**
843
+ - **UploadCodeAsArtifact**
844
+ - **UploadCheckpointsAsArtifact**
845
+
846
+ Callbacks which provide examples of logging custom visualisations:
847
+
848
+ - **LogConfusionMatrix**
849
+ - **LogF1PrecRecHeatmap**
850
+ - **LogImagePredictions**
851
+
852
+ To try all of the callbacks at once, switch to the right branch:
853
+
854
+ ```bash
855
+ git checkout wandb-callbacks
856
+ ```
857
+
858
+ And then run the following command:
859
+
860
+ ```bash
861
+ python train.py logger=wandb callbacks=wandb
862
+ ```
863
+
864
+ To see the result of all the callbacks attached, take a look at [this experiment dashboard](https://wandb.ai/hobglob/template-tests/runs/3rw7q70h).
865
+
866
+ <br>
867
+
868
+ ### Multi-GPU Training
869
+
870
+ Lightning supports multiple ways of doing distributed training. The most common one is DDP, which spawns separate process for each GPU and averages gradients between them. To learn about other approaches read the [lightning docs](https://pytorch-lightning.readthedocs.io/en/latest/advanced/multi_gpu.html).
871
+
872
+ You can run DDP on the MNIST example with 4 GPUs like this:
873
+
874
+ ```bash
875
+ python train.py trainer.gpus=4 +trainer.strategy=ddp
876
+ ```
877
+
878
+ ⚠️ When using DDP you have to be careful how you write your models - learn more [here](https://pytorch-lightning.readthedocs.io/en/latest/advanced/multi_gpu.html).
879
+
880
+ <br>
881
+
882
+ ### Docker
883
+
884
+ First you will need to [install Nvidia Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) to enable GPU support.
885
+
886
+ The template Dockerfile is provided on branch [`dockerfiles`](https://github.com/ashleve/lightning-hydra-template/tree/dockerfiles). Copy it to the template root folder.
887
+
888
+ To build the container use:
889
+
890
+ ```bash
891
+ docker build -t <project_name> .
892
+ ```
893
+
894
+ To mount the project to the container use:
895
+
896
+ ```bash
897
+ docker run -v $(pwd):/workspace/project --gpus all -it --rm <project_name>
898
+ ```
899
+
900
+ <br>
901
+
902
+ ### Reproducibility
903
+
904
+ What provides reproducibility:
905
+
906
+ - Hydra manages your configs
907
+ - Hydra manages your logging paths and makes every executed run store its hyperparameters and config overrides in a separate file in logs
908
+ - Single seed for random number generators in pytorch, numpy and python.random
909
+ - LightningDataModule allows you to encapsulate data split, transformations and default parameters in a single, clean abstraction
910
+ - LightningModule separates your research code from engineering code in a clean way
911
+ - Experiment tracking frameworks take care of logging metrics and hparams, some can also store results and artifacts in cloud
912
+ - Pytorch Lightning takes care of creating training checkpoints
913
+ - Example callbacks for wandb show how you can save and upload a snapshot of codebase every time the run is executed, as well as upload ckpts and track model gradients
914
+
915
+ <!--
916
+ You can load the config of previous run using:
917
+
918
+ ```bash
919
+ python train.py --config-path /logs/runs/.../.hydra/ --config-name config.yaml
920
+ ```
921
+
922
+ The `config.yaml` from `.hydra` folder contains all overriden parameters and sections. This approach however is not officially supported by Hydra and doesn't override the `hydra/` part of the config, meaning logging paths will revert to default!
923
+ -->
924
+ <br>
925
+
926
+ ### Limitations
927
+
928
+ - Currently, the template doesn't support k-fold cross validation, but it's possible to achieve it with the Lightning Loop interface. See the [official example](https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/loop_examples/kfold.py). Implementing it requires rewriting the training pipeline.
929
+ - Pytorch Lightning might not be the best choice for scalable reinforcement learning, it's probably better to use something like [Ray](https://github.com/ray-project/ray).
930
+ - Currently hyperparameter search with the Hydra Optuna Plugin doesn't support pruning.
931
+ - Hydra changes the working directory to a new logging folder for every executed run, which might not be compatible with the way some libraries work.
932
+
933
+ <br>
934
+
935
+ ## Useful Tricks
936
+
937
+ <details>
938
+ <summary><b>Accessing datamodule attributes in model</b></summary>
939
+
940
+ 1. The simplest way is to pass datamodule attribute directly to model on initialization:
941
+
942
+ ```python
943
+ # ./src/training_pipeline.py
944
+ datamodule = hydra.utils.instantiate(config.datamodule)
945
+ model = hydra.utils.instantiate(config.model, some_param=datamodule.some_param)
946
+ ```
947
+
948
+ This is not a very robust solution, since it assumes all your datamodules have `some_param` attribute available (otherwise the run will crash).
949
+
950
+ 2. If you only want to access datamodule config, you can simply pass it as an init parameter:
951
+
952
+ ```python
953
+ # ./src/training_pipeline.py
954
+ model = hydra.utils.instantiate(config.model, dm_conf=config.datamodule, _recursive_=False)
955
+ ```
956
+
957
+ Now you can access any datamodule config part like this:
958
+
959
+ ```python
960
+ # ./src/models/my_model.py
961
+ class MyLitModel(LightningModule):
962
+ def __init__(self, dm_conf, param1, param2):
963
+ super().__init__()
964
+
965
+ batch_size = dm_conf.batch_size
966
+ ```
967
+
968
+ 3. If you need to access the datamodule object attributes, a little hacky solution is to add Omegaconf resolver to your datamodule:
969
+
970
+ ```python
971
+ # ./src/datamodules/my_datamodule.py
972
+ from omegaconf import OmegaConf
973
+
974
+ class MyDataModule(LightningDataModule):
975
+ def __init__(self, param1, param2):
976
+ super().__init__()
977
+
978
+ self.param1 = param1
979
+
980
+ resolver_name = "datamodule"
981
+ OmegaConf.register_new_resolver(
982
+ resolver_name,
983
+ lambda name: getattr(self, name),
984
+ use_cache=False
985
+ )
986
+ ```
987
+
988
+ This way you can reference any datamodule attribute from your config like this:
989
+
990
+ ```yaml
991
+ # this will return attribute 'param1' from datamodule object
992
+ param1: ${datamodule: param1}
993
+ ```
994
+
995
+ When later accessing this field, say in your lightning model, it will get automatically resolved based on all resolvers that are registered. Remember not to access this field before datamodule is initialized or it will crash. **You also need to set `resolve=False` in `print_config()` in [train.py](train.py) or it will throw errors:**
996
+
997
+ ```python
998
+ # ./src/train.py
999
+ utils.print_config(config, resolve=False)
1000
+ ```
1001
+
1002
+ </details>
1003
+
1004
+ <details>
1005
+ <summary><b>Automatic activation of virtual environment and tab completion when entering folder</b></summary>
1006
+
1007
+ 1. Create a new file called `.autoenv` (this name is excluded from version control in `.gitignore`). <br>
1008
+ You can use it to automatically execute shell commands when entering the folder. Add some commands to your `.autoenv` file, like in the example below:
1009
+
1010
+ ```bash
1011
+ # activate conda environment
1012
+ conda activate myenv
1013
+
1014
+ # activate hydra tab completion for bash
1015
+ eval "$(python train.py -sc install=bash)"
1016
+ ```
1017
+
1018
+ (these commands will be executed whenever you're opening or switching the terminal to a folder containing the `.autoenv` file)
1019
+
1020
+ 2. To set up this automation for bash, execute the following line (it will append to your `.bashrc` file):
1021
+
1022
+ ```bash
1023
+ echo "autoenv() { if [ -x .autoenv ]; then source .autoenv ; echo '.autoenv executed' ; fi } ; cd() { builtin cd \"\$@\" ; autoenv ; } ; autoenv" >> ~/.bashrc
1024
+ ```
1025
+
1026
+ 3. Lastly, add execution privileges to your `.autoenv` file:
1027
+
1028
+ ```
1029
+ chmod +x .autoenv
1030
+ ```
1031
+
1032
+ (for safety, only an `.autoenv` file with execution privileges will be executed)
1033
+
1034
+ **Explanation**
1035
+
1036
+ The mentioned line appends 2 commands to your `.bashrc` file:
1037
+
1038
+ 1. `autoenv() { if [ -x .autoenv ]; then source .autoenv ; echo '.autoenv executed' ; fi }` - this declares the `autoenv()` function, which executes the `.autoenv` file if it exists in the current working directory and has execution privileges
1039
+ 2. `cd() { builtin cd \"\$@\" ; autoenv ; } ; autoenv` - this extends the behaviour of the `cd` command, making it execute the `autoenv()` function each time you change folders in the terminal or open a new terminal
1040
+
1041
+ </details>
1042
+
1043
+ <!--
1044
+ <details>
1045
+ <summary><b>Making sweeps failure resistant</b></summary>
1046
+
1047
+ TODO
1048
+
1049
+ </details>
1050
+ -->
1051
+
1052
+ <br>
1053
+
1054
+ ## Best Practices
1055
+
1056
+ <details>
1057
+ <summary><b>Use Miniconda for GPU environments</b></summary>
1058
+
1059
+ Use miniconda for your python environments (it's usually unnecessary to install the full anaconda environment, miniconda should be enough).
1060
+ It makes it easier to install some dependencies, like cudatoolkit for GPU support. It also allows you to access your environments globally.
1061
+
1062
+ Example installation:
1063
+
1064
+ ```bash
1065
+ wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
1066
+ bash Miniconda3-latest-Linux-x86_64.sh
1067
+ ```
1068
+
1069
+ Create new conda environment:
1070
+
1071
+ ```bash
1072
+ conda create -n myenv python=3.8
1073
+ conda activate myenv
1074
+ ```
1075
+
1076
+ </details>
1077
+
1078
+ <details>
1079
+ <summary><b>Use automatic code formatting</b></summary>
1080
+
1081
+ Use pre-commit hooks to standardize code formatting of your project and save mental energy.<br>
1082
+ Simply install pre-commit package with:
1083
+
1084
+ ```bash
1085
+ pip install pre-commit
1086
+ ```
1087
+
1088
+ Next, install hooks from [.pre-commit-config.yaml](.pre-commit-config.yaml):
1089
+
1090
+ ```bash
1091
+ pre-commit install
1092
+ ```
1093
+
1094
+ After that your code will be automatically reformatted on every new commit.<br>
1095
+ Currently the template contains configurations of **black** (python code formatting), **isort** (python import sorting), **flake8** (python code analysis), **prettier** (yaml formatting) and **nbstripout** (clearing output from jupyter notebooks). <br>
1096
+
1097
+ To reformat all files in the project use command:
1098
+
1099
+ ```bash
1100
+ pre-commit run -a
1101
+ ```
1102
+
1103
+ </details>
1104
+
1105
+ <details>
1106
+ <summary><b>Set private environment variables in .env file</b></summary>
1107
+
1108
+ System-specific variables (e.g. absolute paths to datasets) should not be under version control or it will result in conflicts between different users. Your private keys also shouldn't be versioned since you don't want them to be leaked.<br>
1109
+
1110
+ The template contains a `.env.example` file, which serves as an example. Create a new file called `.env` (this name is excluded from version control in .gitignore).
1111
+ You should use it for storing environment variables like this:
1112
+
1113
+ ```
1114
+ MY_VAR=/home/user/my_system_path
1115
+ ```
1116
+
1117
+ All variables from `.env` are loaded in `train.py` automatically.
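+
+ For reference, that loading is typically a one-liner near the top of `train.py` - a minimal sketch, assuming the python-dotenv package is installed (the exact implementation may differ):
+
+ ```python
+ # sketch of loading .env at the start of train.py (assumes python-dotenv is installed)
+ import dotenv
+
+ # read the .env file (if present) and export its variables into the process environment
+ dotenv.load_dotenv(override=True)
+ ```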
1118
+
1119
+ Hydra allows you to reference any env variable in `.yaml` configs like this:
1120
+
1121
+ ```yaml
1122
+ path_to_data: ${oc.env:MY_VAR}
1123
+ ```
1124
+
1125
+ </details>
1126
+
1127
+ <details>
1128
+ <summary><b>Name metrics using '/' character</b></summary>
1129
+
1130
+ Depending on which logger you're using, it's often useful to define metric names with the `/` character:
1131
+
1132
+ ```python
1133
+ self.log("train/loss", loss)
1134
+ ```
1135
+
1136
+ This way loggers will treat your metrics as belonging to different sections, which helps to get them organised in UI.
1137
+
1138
+ </details>
1139
+
1140
+ <details>
1141
+ <summary><b>Use torchmetrics</b></summary>
1142
+
1143
+ Use the official [torchmetrics](https://github.com/PytorchLightning/metrics) library to ensure proper calculation of metrics. This is especially important for multi-GPU training!
1144
+
1145
+ For example, instead of calculating accuracy by yourself, you should use the provided `Accuracy` class like this:
1146
+
1147
+ ```python
1148
+ from torchmetrics.classification.accuracy import Accuracy
1149
+
1150
+
1151
+ class LitModel(LightningModule):
1152
+ def __init__(self):
1153
+ self.train_acc = Accuracy()
1154
+ self.val_acc = Accuracy()
1155
+
1156
+ def training_step(self, batch, batch_idx):
1157
+ ...
1158
+ acc = self.train_acc(predictions, targets)
1159
+ self.log("train/acc", acc)
1160
+ ...
1161
+
1162
+ def validation_step(self, batch, batch_idx):
1163
+ ...
1164
+ acc = self.val_acc(predictions, targets)
1165
+ self.log("val/acc", acc)
1166
+ ...
1167
+ ```
1168
+
1169
+ Make sure to use a different metric instance for each step to ensure proper value reduction over all GPU processes.
1170
+
1171
+ Torchmetrics provides metrics for most use cases, like F1 score or confusion matrix. Read [documentation](https://torchmetrics.readthedocs.io/en/latest/#more-reading) for more.
1172
+
1173
+ </details>
1174
+
1175
+ <details>
1176
+ <summary><b>Follow PyTorch Lightning style guide</b></summary>
1177
+
1178
+ The style guide is available [here](https://pytorch-lightning.readthedocs.io/en/latest/starter/style_guide.html).<br>
1179
+
1180
+ 1. Be explicit in your init. Try to define all the relevant defaults so that the user doesn’t have to guess. Provide type hints. This way your module is reusable across projects!
1181
+
1182
+ ```python
1183
+ class LitModel(LightningModule):
1184
+ def __init__(self, layer_size: int = 256, lr: float = 0.001):
1185
+ ```
1186
+
1187
+ 2. Preserve the recommended method order.
1188
+
1189
+ ```python
1190
+ class LitModel(LightningModule):
1191
+
1192
+ def __init__():
1193
+ ...
1194
+
1195
+ def forward():
1196
+ ...
1197
+
1198
+ def training_step():
1199
+ ...
1200
+
1201
+ def training_step_end():
1202
+ ...
1203
+
1204
+ def training_epoch_end():
1205
+ ...
1206
+
1207
+ def validation_step():
1208
+ ...
1209
+
1210
+ def validation_step_end():
1211
+ ...
1212
+
1213
+ def validation_epoch_end():
1214
+ ...
1215
+
1216
+ def test_step():
1217
+ ...
1218
+
1219
+ def test_step_end():
1220
+ ...
1221
+
1222
+ def test_epoch_end():
1223
+ ...
1224
+
1225
+ def configure_optimizers():
1226
+ ...
1227
+
1228
+ def any_extra_hook():
1229
+ ...
1230
+ ```
1231
+
1232
+ </details>
1233
+
1234
+ <details>
1235
+ <summary><b>Version control your data and models with DVC</b></summary>
1236
+
1237
+ Use [DVC](https://dvc.org) to version control big files, like your data or trained ML models.<br>
1238
+ To initialize the dvc repository:
1239
+
1240
+ ```bash
1241
+ dvc init
1242
+ ```
1243
+
1244
+ To start tracking a file or directory, use `dvc add`:
1245
+
1246
+ ```bash
1247
+ dvc add data/MNIST
1248
+ ```
1249
+
1250
+ DVC stores information about the added file (or a directory) in a special .dvc file named data/MNIST.dvc, a small text file with a human-readable format. This file can be easily versioned like source code with Git, as a placeholder for the original data:
1251
+
1252
+ ```bash
1253
+ git add data/MNIST.dvc data/.gitignore
1254
+ git commit -m "Add raw data"
1255
+ ```
1256
+
1257
+ </details>
1258
+
1259
+ <details>
1260
+ <summary><b>Support installing project as a package</b></summary>
1261
+
1262
+ It allows other people to easily use your modules in their own projects.
1263
+ Change the name of the `src` folder to your project name and add a `setup.py` file:
1264
+
1265
+ ```python
1266
+ from setuptools import find_packages, setup
1267
+
1268
+
1269
+ setup(
1270
+ name="src", # change "src" folder name to your project name
1271
+ version="0.0.0",
1272
+ description="Describe Your Cool Project",
1273
+ author="...",
1274
+ author_email="...",
1275
+ url="https://github.com/ashleve/lightning-hydra-template", # replace with your own github project link
1276
+ install_requires=[
277
+ "torch>=1.10.0",
1278
+ "pytorch-lightning>=1.4.0",
1279
+ "hydra-core>=1.1.0",
1280
+ ],
1281
+ packages=find_packages(),
1282
+ )
1283
+ ```
1284
+
1285
+ Now your project can be installed from local files:
1286
+
1287
+ ```bash
1288
+ pip install -e .
1289
+ ```
1290
+
1291
+ Or directly from git repository:
1292
+
1293
+ ```bash
1294
+ pip install git+https://github.com/YourGithubName/your-repo-name.git --upgrade
1295
+ ```
1296
+
1297
+ Now any module from the project can be imported into any other file like so:
1298
+
1299
+ ```python
1300
+ from project_name.models.mnist_module import MNISTLitModule
1301
+ from project_name.datamodules.mnist_datamodule import MNISTDataModule
1302
+ ```
1303
+
1304
+ </details>
1305
+
1306
+ <!-- <details>
1307
+ <summary><b>Make notebooks independent from other files</b></summary>
1308
+
1309
+ It's a good practice for jupyter notebooks to be portable. Try to make them independent from src files. If you need to access external code, try to embed it inside the notebook.
1310
+
1311
+ </details> -->
1312
+
1313
+ <!--<details>
1314
+ <summary><b>Use Docker</b></summary>
1315
+
1316
+ Docker makes it easy to initialize the whole training environment, e.g. when you want to execute experiments in cloud or on some private computing cluster. You can extend [dockerfiles](https://github.com/ashleve/lightning-hydra-template/tree/dockerfiles) provided in the template with your own instructions for building the image.<br>
1317
+
1318
+ </details> -->
1319
+
1320
+ <br>
1321
+
1322
+ ## Other Repositories
1323
+
1324
+ <details>
1325
+ <summary><b>Inspirations</b></summary>
1326
+
1327
+ This template was inspired by:
1328
+ [PyTorchLightning/deep-learning-project-template](https://github.com/PyTorchLightning/deep-learning-project-template),
1329
+ [drivendata/cookiecutter-data-science](https://github.com/drivendata/cookiecutter-data-science),
1330
+ [tchaton/lightning-hydra-seed](https://github.com/tchaton/lightning-hydra-seed),
1331
+ [Erlemar/pytorch_tempest](https://github.com/Erlemar/pytorch_tempest),
1332
+ [lucmos/nn-template](https://github.com/lucmos/nn-template).
1333
+
1334
+ </details>
1335
+
1336
+ <details>
1337
+ <summary><b>Useful repositories</b></summary>
1338
+
1339
+ - [pytorch/hydra-torch](https://github.com/pytorch/hydra-torch) - resources for configuring PyTorch classes with Hydra
1340
+ - [romesco/hydra-lightning](https://github.com/romesco/hydra-lightning) - resources for configuring PyTorch Lightning classes with Hydra
1341
+ - [lucmos/nn-template](https://github.com/lucmos/nn-template) - similar template
1342
+ - [PyTorchLightning/lightning-transformers](https://github.com/PyTorchLightning/lightning-transformers) - official Lightning Transformers repo built with Hydra
1343
+
1344
+ </details>
1345
+
1346
+ <!-- ## :star:&nbsp; Stargazers Over Time
1347
+ [![Stargazers over time](https://starchart.cc/ashleve/lightning-hydra-template.svg)](https://starchart.cc/ashleve/lightning-hydra-template) -->
1348
+
1349
+ <br>
1350
+
1351
+ ## License
1352
+
1353
+ This project is licensed under the MIT License.
1354
+
1355
+ ```
1356
+ MIT License
1357
+
1358
+ Copyright (c) 2021 ashleve
1359
+
1360
+ Permission is hereby granted, free of charge, to any person obtaining a copy
1361
+ of this software and associated documentation files (the "Software"), to deal
1362
+ in the Software without restriction, including without limitation the rights
1363
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
1364
+ copies of the Software, and to permit persons to whom the Software is
1365
+ furnished to do so, subject to the following conditions:
1366
+
1367
+ The above copyright notice and this permission notice shall be included in all
1368
+ copies or substantial portions of the Software.
1369
+
1370
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
1371
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
1372
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
1373
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
1374
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
1375
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
1376
+ SOFTWARE.
1377
+ ```
1378
+
1379
+ <br>
1380
+ <br>
1381
+ <br>
1382
+ <br>
1383
+
1384
+ **DELETE EVERYTHING ABOVE FOR YOUR PROJECT**
1385
+
1386
+ ---
1387
+
1388
+ <div align="center">
1389
+
1390
+ # Your Project Name
1391
+
1392
+ <a href="https://pytorch.org/get-started/locally/"><img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-ee4c2c?logo=pytorch&logoColor=white"></a>
1393
+ <a href="https://pytorchlightning.ai/"><img alt="Lightning" src="https://img.shields.io/badge/-Lightning-792ee5?logo=pytorchlightning&logoColor=white"></a>
1394
+ <a href="https://hydra.cc/"><img alt="Config: Hydra" src="https://img.shields.io/badge/Config-Hydra-89b8cd"></a>
1395
+ <a href="https://github.com/ashleve/lightning-hydra-template"><img alt="Template" src="https://img.shields.io/badge/-Lightning--Hydra--Template-017F2F?style=flat&logo=github&labelColor=gray"></a><br>
1396
+ [![Paper](http://img.shields.io/badge/paper-arxiv.1001.2234-B31B1B.svg)](https://www.nature.com/articles/nature14539)
1397
+ [![Conference](http://img.shields.io/badge/AnyConference-year-4b44ce.svg)](https://papers.nips.cc/paper/2020)
1398
+
1399
+ </div>
1400
+
1401
+ ## Description
1402
+
1403
+ What it does
1404
+
1405
+ ## How to run
1406
+
1407
+ Install dependencies
1408
+
1409
+ ```bash
1410
+ # clone project
1411
+ git clone https://github.com/YourGithubName/your-repo-name
1412
+ cd your-repo-name
1413
+
1414
+ # [OPTIONAL] create conda environment
1415
+ conda create -n myenv python=3.8
1416
+ conda activate myenv
1417
+
1418
+ # install pytorch according to instructions
1419
+ # https://pytorch.org/get-started/
1420
+
1421
+ # install requirements
1422
+ pip install -r requirements.txt
1423
+ ```
1424
+
1425
+ Train model with default configuration
1426
+
1427
+ ```bash
1428
+ # train on CPU
1429
+ python train.py trainer.gpus=0
1430
+
1431
+ # train on GPU
1432
+ python train.py trainer.gpus=1
1433
+ ```
1434
+
1435
+ Train model with chosen experiment configuration from [configs/experiment/](configs/experiment/)
1436
+
1437
+ ```bash
1438
+ python train.py experiment=experiment_name.yaml
1439
+ ```
1440
+
1441
+ You can override any parameter from the command line like this:
1442
+
1443
+ ```bash
1444
+ python train.py trainer.max_epochs=20 datamodule.batch_size=64
1445
+ ```
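
Hydra can also sweep over several values of a parameter in one command with the multirun flag; for example:

```bash
# launch one training run per batch size; each run gets its own multirun log subdirectory
python train.py -m datamodule.batch_size=32,64,128
```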
models/configs/callbacks/default.yaml ADDED
@@ -0,0 +1,24 @@
1
+ model_checkpoint:
2
+ _target_: pytorch_lightning.callbacks.ModelCheckpoint
3
+ monitor: "val/acc" # name of the logged metric which determines when model is improving
4
+ mode: "max" # "max" means higher metric value is better, can be also "min"
5
+ save_top_k: 1 # save k best models (determined by above metric)
6
+ save_last: True # additionally always save the model from the last epoch
7
+ verbose: False
8
+ dirpath: "checkpoints/"
9
+ filename: "epoch_{epoch:03d}"
10
+ auto_insert_metric_name: False
11
+
12
+ early_stopping:
13
+ _target_: pytorch_lightning.callbacks.EarlyStopping
14
+ monitor: "val/acc" # name of the logged metric which determines when model is improving
15
+ mode: "max" # "max" means higher metric value is better, can be also "min"
16
+ patience: 100 # how many validation epochs without improvement before training stops
17
+ min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement
18
+
19
+ model_summary:
20
+ _target_: pytorch_lightning.callbacks.RichModelSummary
21
+ max_depth: -1
22
+
23
+ rich_progress_bar:
24
+ _target_: pytorch_lightning.callbacks.RichProgressBar
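
Each entry above points Hydra at a callback class through `_target_`; in the training entry point such configs are typically turned into objects with `hydra.utils.instantiate` and handed to the `Trainer`. A rough sketch of that pattern (not the exact code in this repo):

```python
import hydra
from omegaconf import DictConfig


def build_callbacks(cfg: DictConfig) -> list:
    """Instantiate every callback defined under cfg.callbacks."""
    callbacks = []
    for _, cb_conf in cfg.get("callbacks", {}).items():
        if "_target_" in cb_conf:
            callbacks.append(hydra.utils.instantiate(cb_conf))
    return callbacks


# later: trainer = pytorch_lightning.Trainer(callbacks=build_callbacks(cfg), ...)
```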
models/configs/callbacks/none.yaml ADDED
File without changes
models/configs/datamodule/mnist.yaml ADDED
@@ -0,0 +1,7 @@
1
+ _target_: src.datamodules.mnist_datamodule.MNISTDataModule
2
+
3
+ data_dir: ${data_dir} # data_dir is specified in config.yaml
4
+ batch_size: 64
5
+ train_val_test_split: [55_000, 5_000, 10_000]
6
+ num_workers: 0
7
+ pin_memory: False
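
Here `_target_` names the class to build and the remaining keys become its constructor arguments (`${data_dir}` is interpolated from the root config). Written out by hand, and assuming the datamodule's constructor mirrors these keys, the instantiation amounts to:

```python
from src.datamodules.mnist_datamodule import MNISTDataModule

datamodule = MNISTDataModule(
    data_dir="data/",  # value normally filled in from ${data_dir}
    batch_size=64,
    train_val_test_split=[55_000, 5_000, 10_000],
    num_workers=0,
    pin_memory=False,
)
```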
models/configs/debug/default.yaml ADDED
@@ -0,0 +1,28 @@
1
+ # @package _global_
2
+
3
+ # default debugging setup, runs 1 full epoch
4
+ # other debugging configs can inherit from this one
5
+
6
+ defaults:
7
+ - override /log_dir: debug.yaml
8
+
9
+ trainer:
10
+ max_epochs: 1
11
+ gpus: 0 # debuggers don't like gpus
12
+ detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor
13
+ track_grad_norm: 2 # track gradient norm with loggers
14
+
15
+ datamodule:
16
+ num_workers: 0 # debuggers don't like multiprocessing
17
+ pin_memory: False # disable gpu memory pin
18
+
19
+ # sets level of all command line loggers to 'DEBUG'
20
+ # https://hydra.cc/docs/tutorials/basic/running_your_app/logging/
21
+ hydra:
22
+ verbose: True
23
+
24
+ # use this to set level of only chosen command line loggers to 'DEBUG':
25
+ # verbose: [src.train, src.utils]
26
+
27
+ # config is already printed by hydra when `hydra/verbose: True`
28
+ print_config: False
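
Assuming the root config exposes a `debug` config group (as in the upstream template), these presets are selected at launch time, e.g.:

```bash
# one full epoch on CPU with verbose logging and anomaly detection
python train.py debug=default

# single train/val/test step for a quick sanity check
python train.py debug=step
```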
models/configs/debug/limit_batches.yaml ADDED
@@ -0,0 +1,12 @@
1
+ # @package _global_
2
+
3
+ # uses only 1% of the training data and 5% of validation/test data
4
+
5
+ defaults:
6
+ - default.yaml
7
+
8
+ trainer:
9
+ max_epochs: 3
10
+ limit_train_batches: 0.01
11
+ limit_val_batches: 0.05
12
+ limit_test_batches: 0.05
models/configs/debug/overfit.yaml ADDED
@@ -0,0 +1,10 @@
1
+ # @package _global_
2
+
3
+ # overfits to 3 batches
4
+
5
+ defaults:
6
+ - default.yaml
7
+
8
+ trainer:
9
+ max_epochs: 20
10
+ overfit_batches: 3
models/configs/debug/profiler.yaml ADDED
@@ -0,0 +1,12 @@
1
+ # @package _global_
2
+
3
+ # runs with execution time profiling
4
+
5
+ defaults:
6
+ - default.yaml
7
+
8
+ trainer:
9
+ max_epochs: 1
10
+ profiler: "simple"
11
+ # profiler: "advanced"
12
+ # profiler: "pytorch"
models/configs/debug/step.yaml ADDED
@@ -0,0 +1,9 @@
1
+ # @package _global_
2
+
3
+ # runs 1 train, 1 validation and 1 test step
4
+
5
+ defaults:
6
+ - default.yaml
7
+
8
+ trainer:
9
+ fast_dev_run: true
models/configs/debug/test_only.yaml ADDED
@@ -0,0 +1,9 @@
1
+ # @package _global_
2
+
3
+ # runs only test epoch
4
+
5
+ defaults:
6
+ - default.yaml
7
+
8
+ train: False
9
+ test: True
models/configs/experiment/example.yaml ADDED
@@ -0,0 +1,37 @@
1
+ # @package _global_
2
+
3
+ # to execute this experiment run:
4
+ # python train.py experiment=example
5
+
6
+ defaults:
7
+ - override /datamodule: mnist.yaml
8
+ - override /model: mnist.yaml
9
+ - override /callbacks: default.yaml
10
+ - override /logger: null
11
+ - override /trainer: default.yaml
12
+
13
+ # all parameters below will be merged with parameters from default configurations set above
14
+ # this allows you to overwrite only specified parameters
15
+
16
+ # name of the run determines folder name in logs
17
+ name: "simple_dense_net"
18
+
19
+ seed: 12345
20
+
21
+ trainer:
22
+ min_epochs: 10
23
+ max_epochs: 10
24
+ gradient_clip_val: 0.5
25
+
26
+ model:
27
+ lin1_size: 128
28
+ lin2_size: 256
29
+ lin3_size: 64
30
+ lr: 0.002
31
+
32
+ datamodule:
33
+ batch_size: 64
34
+
35
+ logger:
36
+ wandb:
37
+ tags: ["mnist", "${name}"]
models/configs/hparams_search/mnist_optuna.yaml ADDED
@@ -0,0 +1,60 @@
1
+ # @package _global_
2
+
3
+ # example hyperparameter optimization of some experiment with Optuna:
4
+ # python train.py -m hparams_search=mnist_optuna experiment=example
5
+
6
+ defaults:
7
+ - override /hydra/sweeper: optuna
8
+
9
+ # choose metric which will be optimized by Optuna
10
+ # make sure this is the correct name of some metric logged in lightning module!
11
+ optimized_metric: "val/acc_best"
12
+
13
+ # here we define Optuna hyperparameter search
14
+ # it optimizes the value returned by the function decorated with @hydra.main
15
+ # docs: https://hydra.cc/docs/next/plugins/optuna_sweeper
16
+ hydra:
17
+ sweeper:
18
+ _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper
19
+
20
+ # storage URL to persist optimization results
21
+ # for example, you can use SQLite if you set 'sqlite:///example.db'
22
+ storage: null
23
+
24
+ # name of the study to persist optimization results
25
+ study_name: null
26
+
27
+ # number of parallel workers
28
+ n_jobs: 1
29
+
30
+ # 'minimize' or 'maximize' the objective
31
+ direction: maximize
32
+
33
+ # total number of runs that will be executed
34
+ n_trials: 25
35
+
36
+ # choose Optuna hyperparameter sampler
37
+ # docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html
38
+ sampler:
39
+ _target_: optuna.samplers.TPESampler
40
+ seed: 12345
41
+ n_startup_trials: 10 # number of random sampling runs before optimization starts
42
+
43
+ # define range of hyperparameters
44
+ search_space:
45
+ datamodule.batch_size:
46
+ type: categorical
47
+ choices: [32, 64, 128]
48
+ model.lr:
49
+ type: float
50
+ low: 0.0001
51
+ high: 0.2
52
+ model.lin1_size:
53
+ type: categorical
54
+ choices: [32, 64, 128, 256, 512]
55
+ model.lin2_size:
56
+ type: categorical
57
+ choices: [32, 64, 128, 256, 512]
58
+ model.lin3_size:
59
+ type: categorical
60
+ choices: [32, 64, 128, 256, 512]
models/configs/local/.gitkeep ADDED
File without changes
models/configs/log_dir/debug.yaml ADDED
@@ -0,0 +1,8 @@
1
+ # @package _global_
2
+
3
+ hydra:
4
+ run:
5
+ dir: logs/debugs/runs/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}
6
+ sweep:
7
+ dir: logs/debugs/multiruns/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}
8
+ subdir: ${hydra.job.num}
models/configs/log_dir/default.yaml ADDED
@@ -0,0 +1,8 @@
1
+ # @package _global_
2
+
3
+ hydra:
4
+ run:
5
+ dir: logs/experiments/runs/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}
6
+ sweep:
7
+ dir: logs/experiments/multiruns/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}
8
+ subdir: ${hydra.job.num}
models/configs/log_dir/evaluation.yaml ADDED
@@ -0,0 +1,8 @@
1
+ # @package _global_
2
+
3
+ hydra:
4
+ run:
5
+ dir: logs/evaluations/runs/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}
6
+ sweep:
7
+ dir: logs/evaluations/multiruns/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}
8
+ subdir: ${hydra.job.num}
models/configs/logger/comet.yaml ADDED
@@ -0,0 +1,7 @@
1
+ # https://www.comet.ml
2
+
3
+ comet:
4
+ _target_: pytorch_lightning.loggers.comet.CometLogger
5
+ api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable
6
+ project_name: "template-tests"
7
+ experiment_name: ${name}
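
Since the API key is read from the `COMET_API_TOKEN` environment variable, export it before selecting this logger (assuming a `logger` config group in the root config):

```bash
export COMET_API_TOKEN="your-comet-api-key"
python train.py logger=comet
```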
models/configs/logger/csv.yaml ADDED
@@ -0,0 +1,7 @@
1
+ # CSV logger built into Lightning
2
+
3
+ csv:
4
+ _target_: pytorch_lightning.loggers.csv_logs.CSVLogger
5
+ save_dir: "."
6
+ name: "csv/"
7
+ prefix: ""
models/configs/logger/many_loggers.yaml ADDED
@@ -0,0 +1,9 @@
1
+ # train with many loggers at once
2
+
3
+ defaults:
4
+ # - comet.yaml
5
+ - csv.yaml
6
+ # - mlflow.yaml
7
+ # - neptune.yaml
8
+ - tensorboard.yaml
9
+ - wandb.yaml
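
Any logger config can be picked at launch time, and this one composes several of them at once; assuming a `logger` group in the root config:

```bash
# a single logger
python train.py logger=csv

# csv + tensorboard + wandb together, as configured above
python train.py logger=many_loggers
```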
models/configs/logger/mlflow.yaml ADDED
@@ -0,0 +1,10 @@
1
+ # https://mlflow.org
2
+
3
+ mlflow:
4
+ _target_: pytorch_lightning.loggers.mlflow.MLFlowLogger
5
+ experiment_name: ${name}
6
+ tracking_uri: null
7
+ tags: null
8
+ save_dir: ./mlruns
9
+ prefix: ""
10
+ artifact_location: null