Hannes Kuchelmeister committed
Commit: d2e7940
Parent(s): 554c212

remove old template and use https://github.com/ashleve/lightning-hydra-template instead
This view is limited to 50 files because it contains too many changes.
- model/.dockerignore +0 -4
- model/Dockerfile +0 -13
- model/LICENSE +0 -21
- model/README.md +0 -378
- model/base/__init__.py +0 -3
- model/base/base_data_loader.py +0 -61
- model/base/base_model.py +0 -25
- model/base/base_trainer.py +0 -151
- model/config.json +0 -50
- model/data_loader/data_loaders.py +0 -16
- model/docker-compose.yml +0 -12
- model/logger/__init__.py +0 -2
- model/logger/logger.py +0 -22
- model/logger/logger_config.json +0 -32
- model/logger/visualization.py +0 -73
- model/model/loss.py +0 -5
- model/model/metric.py +0 -20
- model/model/model.py +0 -22
- model/new_project.py +0 -18
- model/parse_config.py +0 -157
- model/requirements.txt +0 -6
- model/test.py +0 -81
- model/train.py +0 -73
- model/trainer/__init__.py +0 -1
- model/trainer/trainer.py +0 -110
- model/utils/__init__.py +0 -1
- model/utils/util.py +0 -67
- models/.env.example +7 -0
- models/.gitignore +148 -0
- models/.pre-commit-config.yaml +44 -0
- models/README.md +1445 -0
- models/configs/callbacks/default.yaml +24 -0
- models/configs/callbacks/none.yaml +0 -0
- models/configs/datamodule/mnist.yaml +7 -0
- models/configs/debug/default.yaml +28 -0
- models/configs/debug/limit_batches.yaml +12 -0
- models/configs/debug/overfit.yaml +10 -0
- models/configs/debug/profiler.yaml +12 -0
- models/configs/debug/step.yaml +9 -0
- models/configs/debug/test_only.yaml +9 -0
- models/configs/experiment/example.yaml +37 -0
- models/configs/hparams_search/mnist_optuna.yaml +60 -0
- models/configs/local/.gitkeep +0 -0
- models/configs/log_dir/debug.yaml +8 -0
- models/configs/log_dir/default.yaml +8 -0
- models/configs/log_dir/evaluation.yaml +8 -0
- models/configs/logger/comet.yaml +7 -0
- models/configs/logger/csv.yaml +7 -0
- models/configs/logger/many_loggers.yaml +9 -0
- models/configs/logger/mlflow.yaml +10 -0
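The new `models/` directory follows the Hydra + PyTorch Lightning layout of the linked lightning-hydra-template, where a training run is composed from small YAML config groups (datamodule, callbacks, logger, debug, experiment, hparams_search). As a purely illustrative sketch of that idea — not the actual contents of `models/configs/experiment/example.yaml`, which is not shown in this view — an experiment config typically overrides a handful of defaults:

```yaml
# @package _global_
# Hypothetical experiment override, for illustration only; the real
# models/configs/experiment/example.yaml is not shown in this commit view.
defaults:
  - override /datamodule: mnist.yaml   # reuse the MNIST datamodule config group
  - override /logger: csv.yaml         # log metrics to CSV

seed: 12345

trainer:
  max_epochs: 10

datamodule:
  batch_size: 64
```

Assuming the template's usual entry point, running `python train.py experiment=example` would merge such overrides into the default config tree.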
model/.dockerignore
DELETED
@@ -1,4 +0,0 @@
.env
in/
out/
new_proect.py
model/Dockerfile
DELETED
@@ -1,13 +0,0 @@
FROM python:3.7

WORKDIR /usr/src/app

RUN apt-get update
RUN apt-get install libgl1 -y

COPY requirements.txt ./
RUN pip install --no-cache-dir -r requirements.txt

COPY . .

CMD sh -c "python train.py --config config.json"
model/LICENSE
DELETED
@@ -1,21 +0,0 @@
MIT License

Copyright (c) 2018 Victor Huang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
model/README.md
DELETED
@@ -1,378 +0,0 @@
# PyTorch Template Project
PyTorch deep learning project made easy.

<!-- @import "[TOC]" {cmd="toc" depthFrom=1 depthTo=6 orderedList=false} -->

<!-- code_chunk_output -->

* [PyTorch Template Project](#pytorch-template-project)
    * [Requirements](#requirements)
    * [Features](#features)
    * [Folder Structure](#folder-structure)
    * [Usage](#usage)
        * [Config file format](#config-file-format)
        * [Using config files](#using-config-files)
        * [Resuming from checkpoints](#resuming-from-checkpoints)
        * [Using Multiple GPU](#using-multiple-gpu)
    * [Customization](#customization)
        * [Custom CLI options](#custom-cli-options)
        * [Data Loader](#data-loader)
        * [Trainer](#trainer)
        * [Model](#model)
        * [Loss](#loss)
        * [metrics](#metrics)
        * [Additional logging](#additional-logging)
        * [Validation data](#validation-data)
        * [Checkpoints](#checkpoints)
        * [Tensorboard Visualization](#tensorboard-visualization)
    * [Contribution](#contribution)
    * [TODOs](#todos)
    * [License](#license)
    * [Acknowledgements](#acknowledgements)

<!-- /code_chunk_output -->

## Requirements
* Python >= 3.5 (3.6 recommended)
* PyTorch >= 0.4 (1.2 recommended)
* tqdm (Optional for `test.py`)
* tensorboard >= 1.14 (see [Tensorboard Visualization](#tensorboard-visualization))

## Features
* Clear folder structure which is suitable for many deep learning projects.
* `.json` config file support for convenient parameter tuning.
* Customizable command line options for more convenient parameter tuning.
* Checkpoint saving and resuming.
* Abstract base classes for faster development:
  * `BaseTrainer` handles checkpoint saving/resuming, training process logging, and more.
  * `BaseDataLoader` handles batch generation, data shuffling, and validation data splitting.
  * `BaseModel` provides basic model summary.

## Folder Structure
```
pytorch-template/
│
├── train.py - main script to start training
├── test.py - evaluation of trained model
│
├── config.json - holds configuration for training
├── parse_config.py - class to handle config file and cli options
│
├── new_project.py - initialize new project with template files
│
├── base/ - abstract base classes
│   ├── base_data_loader.py
│   ├── base_model.py
│   └── base_trainer.py
│
├── data_loader/ - anything about data loading goes here
│   └── data_loaders.py
│
├── data/ - default directory for storing input data
│
├── model/ - models, losses, and metrics
│   ├── model.py
│   ├── metric.py
│   └── loss.py
│
├── saved/
│   ├── models/ - trained models are saved here
│   └── log/ - default logdir for tensorboard and logging output
│
├── trainer/ - trainers
│   └── trainer.py
│
├── logger/ - module for tensorboard visualization and logging
│   ├── visualization.py
│   ├── logger.py
│   └── logger_config.json
│
└── utils/ - small utility functions
    ├── util.py
    └── ...
```

## Usage
The code in this repo is an MNIST example of the template.
Try `python train.py -c config.json` to run code.

### Config file format
Config files are in `.json` format:
```javascript
{
  "name": "Mnist_LeNet",        // training session name
  "n_gpu": 1,                   // number of GPUs to use for training.

  "arch": {
    "type": "MnistModel",       // name of model architecture to train
    "args": {

    }
  },
  "data_loader": {
    "type": "MnistDataLoader",  // selecting data loader
    "args":{
      "data_dir": "data/",      // dataset path
      "batch_size": 64,         // batch size
      "shuffle": true,          // shuffle training data before splitting
      "validation_split": 0.1   // size of validation dataset. float(portion) or int(number of samples)
      "num_workers": 2,         // number of cpu processes to be used for data loading
    }
  },
  "optimizer": {
    "type": "Adam",
    "args":{
      "lr": 0.001,              // learning rate
      "weight_decay": 0,        // (optional) weight decay
      "amsgrad": true
    }
  },
  "loss": "nll_loss",           // loss
  "metrics": [
    "accuracy", "top_k_acc"     // list of metrics to evaluate
  ],
  "lr_scheduler": {
    "type": "StepLR",           // learning rate scheduler
    "args":{
      "step_size": 50,
      "gamma": 0.1
    }
  },
  "trainer": {
    "epochs": 100,              // number of training epochs
    "save_dir": "saved/",       // checkpoints are saved in save_dir/models/name
    "save_freq": 1,             // save checkpoints every save_freq epochs
    "verbosity": 2,             // 0: quiet, 1: per epoch, 2: full

    "monitor": "min val_loss"   // mode and metric for model performance monitoring. set 'off' to disable.
    "early_stop": 10            // number of epochs to wait before early stop. set 0 to disable.

    "tensorboard": true,        // enable tensorboard visualization
  }
}
```

Add addional configurations if you need.

### Using config files
Modify the configurations in `.json` config files, then run:

```
python train.py --config config.json
```

### Resuming from checkpoints
You can resume from a previously saved checkpoint by:

```
python train.py --resume path/to/checkpoint
```

### Using Multiple GPU
You can enable multi-GPU training by setting `n_gpu` argument of the config file to larger number.
If configured to use smaller number of gpu than available, first n devices will be used by default.
Specify indices of available GPUs by cuda environmental variable.
```
python train.py --device 2,3 -c config.json
```
This is equivalent to
```
CUDA_VISIBLE_DEVICES=2,3 python train.py -c config.py
```

## Customization

### Project initialization
Use the `new_project.py` script to make your new project directory with template files.
`python new_project.py ../NewProject` then a new project folder named 'NewProject' will be made.
This script will filter out unneccessary files like cache, git files or readme file.

### Custom CLI options

Changing values of config file is a clean, safe and easy way of tuning hyperparameters. However, sometimes
it is better to have command line options if some values need to be changed too often or quickly.

This template uses the configurations stored in the json file by default, but by registering custom options as follows
you can change some of them using CLI flags.

```python
# simple class-like object having 3 attributes, `flags`, `type`, `target`.
CustomArgs = collections.namedtuple('CustomArgs', 'flags type target')
options = [
    CustomArgs(['--lr', '--learning_rate'], type=float, target=('optimizer', 'args', 'lr')),
    CustomArgs(['--bs', '--batch_size'], type=int, target=('data_loader', 'args', 'batch_size'))
    # options added here can be modified by command line flags.
]
```
`target` argument should be sequence of keys, which are used to access that option in the config dict. In this example, `target`
for the learning rate option is `('optimizer', 'args', 'lr')` because `config['optimizer']['args']['lr']` points to the learning rate.
`python train.py -c config.json --bs 256` runs training with options given in `config.json` except for the `batch size`
which is increased to 256 by command line options.


### Data Loader
* **Writing your own data loader**

1. **Inherit ```BaseDataLoader```**

    `BaseDataLoader` is a subclass of `torch.utils.data.DataLoader`, you can use either of them.

    `BaseDataLoader` handles:
    * Generating next batch
    * Data shuffling
    * Generating validation data loader by calling
    `BaseDataLoader.split_validation()`

* **DataLoader Usage**

  `BaseDataLoader` is an iterator, to iterate through batches:
  ```python
  for batch_idx, (x_batch, y_batch) in data_loader:
      pass
  ```
* **Example**

  Please refer to `data_loader/data_loaders.py` for an MNIST data loading example.

### Trainer
* **Writing your own trainer**

1. **Inherit ```BaseTrainer```**

    `BaseTrainer` handles:
    * Training process logging
    * Checkpoint saving
    * Checkpoint resuming
    * Reconfigurable performance monitoring for saving current best model, and early stop training.
      * If config `monitor` is set to `max val_accuracy`, which means then the trainer will save a checkpoint `model_best.pth` when `validation accuracy` of epoch replaces current `maximum`.
      * If config `early_stop` is set, training will be automatically terminated when model performance does not improve for given number of epochs. This feature can be turned off by passing 0 to the `early_stop` option, or just deleting the line of config.

2. **Implementing abstract methods**

    You need to implement `_train_epoch()` for your training process, if you need validation then you can implement `_valid_epoch()` as in `trainer/trainer.py`

* **Example**

  Please refer to `trainer/trainer.py` for MNIST training.

* **Iteration-based training**

  `Trainer.__init__` takes an optional argument, `len_epoch` which controls number of batches(steps) in each epoch.

### Model
* **Writing your own model**

1. **Inherit `BaseModel`**

    `BaseModel` handles:
    * Inherited from `torch.nn.Module`
    * `__str__`: Modify native `print` function to prints the number of trainable parameters.

2. **Implementing abstract methods**

    Implement the foward pass method `forward()`

* **Example**

  Please refer to `model/model.py` for a LeNet example.

### Loss
Custom loss functions can be implemented in 'model/loss.py'. Use them by changing the name given in "loss" in config file, to corresponding name.

### Metrics
Metric functions are located in 'model/metric.py'.

You can monitor multiple metrics by providing a list in the configuration file, e.g.:
```json
"metrics": ["accuracy", "top_k_acc"],
```

### Additional logging
If you have additional information to be logged, in `_train_epoch()` of your trainer class, merge them with `log` as shown below before returning:

```python
additional_log = {"gradient_norm": g, "sensitivity": s}
log.update(additional_log)
return log
```

### Testing
You can test trained model by running `test.py` passing path to the trained checkpoint by `--resume` argument.

### Validation data
To split validation data from a data loader, call `BaseDataLoader.split_validation()`, then it will return a data loader for validation of size specified in your config file.
The `validation_split` can be a ratio of validation set per total data(0.0 <= float < 1.0), or the number of samples (0 <= int < `n_total_samples`).

**Note**: the `split_validation()` method will modify the original data loader
**Note**: `split_validation()` will return `None` if `"validation_split"` is set to `0`

### Checkpoints
You can specify the name of the training session in config files:
```json
"name": "MNIST_LeNet",
```

The checkpoints will be saved in `save_dir/name/timestamp/checkpoint_epoch_n`, with timestamp in mmdd_HHMMSS format.

A copy of config file will be saved in the same folder.

**Note**: checkpoints contain:
```python
{
  'arch': arch,
  'epoch': epoch,
  'state_dict': self.model.state_dict(),
  'optimizer': self.optimizer.state_dict(),
  'monitor_best': self.mnt_best,
  'config': self.config
}
```

### Tensorboard Visualization
This template supports Tensorboard visualization by using either `torch.utils.tensorboard` or [TensorboardX](https://github.com/lanpa/tensorboardX).

1. **Install**

    If you are using pytorch 1.1 or higher, install tensorboard by 'pip install tensorboard>=1.14.0'.

    Otherwise, you should install tensorboardx. Follow installation guide in [TensorboardX](https://github.com/lanpa/tensorboardX).

2. **Run training**

    Make sure that `tensorboard` option in the config file is turned on.

    ```
    "tensorboard" : true
    ```

3. **Open Tensorboard server**

    Type `tensorboard --logdir saved/log/` at the project root, then server will open at `http://localhost:6006`

By default, values of loss and metrics specified in config file, input images, and histogram of model parameters will be logged.
If you need more visualizations, use `add_scalar('tag', data)`, `add_image('tag', image)`, etc in the `trainer._train_epoch` method.
`add_something()` methods in this template are basically wrappers for those of `tensorboardX.SummaryWriter` and `torch.utils.tensorboard.SummaryWriter` modules.

**Note**: You don't have to specify current steps, since `WriterTensorboard` class defined at `logger/visualization.py` will track current steps.

## Contribution
Feel free to contribute any kind of function or enhancement, here the coding style follows PEP8

Code should pass the [Flake8](http://flake8.pycqa.org/en/latest/) check before committing.

## TODOs

- [ ] Multiple optimizers
- [ ] Support more tensorboard functions
- [x] Using fixed random seed
- [x] Support pytorch native tensorboard
- [x] `tensorboardX` logger support
- [x] Configurable logging layout, checkpoint naming
- [x] Iteration-based training (instead of epoch-based)
- [x] Adding command line option for fine-tuning

## License
This project is licensed under the MIT License. See LICENSE for more details

## Acknowledgements
This project is inspired by the project [Tensorflow-Project-Template](https://github.com/MrGemy95/Tensorflow-Project-Template) by [Mahmoud Gemy](https://github.com/MrGemy95)
model/base/__init__.py
DELETED
@@ -1,3 +0,0 @@
from .base_data_loader import *
from .base_model import *
from .base_trainer import *
model/base/base_data_loader.py
DELETED
@@ -1,61 +0,0 @@
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data.sampler import SubsetRandomSampler


class BaseDataLoader(DataLoader):
    """
    Base class for all data loaders
    """
    def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=default_collate):
        self.validation_split = validation_split
        self.shuffle = shuffle

        self.batch_idx = 0
        self.n_samples = len(dataset)

        self.sampler, self.valid_sampler = self._split_sampler(self.validation_split)

        self.init_kwargs = {
            'dataset': dataset,
            'batch_size': batch_size,
            'shuffle': self.shuffle,
            'collate_fn': collate_fn,
            'num_workers': num_workers
        }
        super().__init__(sampler=self.sampler, **self.init_kwargs)

    def _split_sampler(self, split):
        if split == 0.0:
            return None, None

        idx_full = np.arange(self.n_samples)

        np.random.seed(0)
        np.random.shuffle(idx_full)

        if isinstance(split, int):
            assert split > 0
            assert split < self.n_samples, "validation set size is configured to be larger than entire dataset."
            len_valid = split
        else:
            len_valid = int(self.n_samples * split)

        valid_idx = idx_full[0:len_valid]
        train_idx = np.delete(idx_full, np.arange(0, len_valid))

        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetRandomSampler(valid_idx)

        # turn off shuffle option which is mutually exclusive with sampler
        self.shuffle = False
        self.n_samples = len(train_idx)

        return train_sampler, valid_sampler

    def split_validation(self):
        if self.valid_sampler is None:
            return None
        else:
            return DataLoader(sampler=self.valid_sampler, **self.init_kwargs)
model/base/base_model.py
DELETED
@@ -1,25 +0,0 @@
import torch.nn as nn
import numpy as np
from abc import abstractmethod


class BaseModel(nn.Module):
    """
    Base class for all models
    """
    @abstractmethod
    def forward(self, *inputs):
        """
        Forward pass logic

        :return: Model output
        """
        raise NotImplementedError

    def __str__(self):
        """
        Model prints with number of trainable parameters
        """
        model_parameters = filter(lambda p: p.requires_grad, self.parameters())
        params = sum([np.prod(p.size()) for p in model_parameters])
        return super().__str__() + '\nTrainable parameters: {}'.format(params)
model/base/base_trainer.py
DELETED
@@ -1,151 +0,0 @@
import torch
from abc import abstractmethod
from numpy import inf
from logger import TensorboardWriter


class BaseTrainer:
    """
    Base class for all trainers
    """
    def __init__(self, model, criterion, metric_ftns, optimizer, config):
        self.config = config
        self.logger = config.get_logger('trainer', config['trainer']['verbosity'])

        self.model = model
        self.criterion = criterion
        self.metric_ftns = metric_ftns
        self.optimizer = optimizer

        cfg_trainer = config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.save_period = cfg_trainer['save_period']
        self.monitor = cfg_trainer.get('monitor', 'off')

        # configuration to monitor model performance and save best
        if self.monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ['min', 'max']

            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
            self.early_stop = cfg_trainer.get('early_stop', inf)
            if self.early_stop <= 0:
                self.early_stop = inf

        self.start_epoch = 1

        self.checkpoint_dir = config.save_dir

        # setup visualization writer instance
        self.writer = TensorboardWriter(config.log_dir, self.logger, cfg_trainer['tensorboard'])

        if config.resume is not None:
            self._resume_checkpoint(config.resume)

    @abstractmethod
    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Current epoch number
        """
        raise NotImplementedError

    def train(self):
        """
        Full training logic
        """
        not_improved_count = 0
        for epoch in range(self.start_epoch, self.epochs + 1):
            result = self._train_epoch(epoch)

            # save logged informations into log dict
            log = {'epoch': epoch}
            log.update(result)

            # print logged informations to the screen
            for key, value in log.items():
                self.logger.info('    {:15s}: {}'.format(str(key), value))

            # evaluate model performance according to configured metric, save best checkpoint as model_best
            best = False
            if self.mnt_mode != 'off':
                try:
                    # check whether model performance improved or not, according to specified metric(mnt_metric)
                    improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \
                               (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best)
                except KeyError:
                    self.logger.warning("Warning: Metric '{}' is not found. "
                                        "Model performance monitoring is disabled.".format(self.mnt_metric))
                    self.mnt_mode = 'off'
                    improved = False

                if improved:
                    self.mnt_best = log[self.mnt_metric]
                    not_improved_count = 0
                    best = True
                else:
                    not_improved_count += 1

                if not_improved_count > self.early_stop:
                    self.logger.info("Validation performance didn\'t improve for {} epochs. "
                                     "Training stops.".format(self.early_stop))
                    break

            if epoch % self.save_period == 0:
                self._save_checkpoint(epoch, save_best=best)

    def _save_checkpoint(self, epoch, save_best=False):
        """
        Saving checkpoints

        :param epoch: current epoch number
        :param log: logging information of the epoch
        :param save_best: if True, rename the saved checkpoint to 'model_best.pth'
        """
        arch = type(self.model).__name__
        state = {
            'arch': arch,
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'monitor_best': self.mnt_best,
            'config': self.config
        }
        filename = str(self.checkpoint_dir / 'checkpoint-epoch{}.pth'.format(epoch))
        torch.save(state, filename)
        self.logger.info("Saving checkpoint: {} ...".format(filename))
        if save_best:
            best_path = str(self.checkpoint_dir / 'model_best.pth')
            torch.save(state, best_path)
            self.logger.info("Saving current best: model_best.pth ...")

    def _resume_checkpoint(self, resume_path):
        """
        Resume from saved checkpoints

        :param resume_path: Checkpoint path to be resumed
        """
        resume_path = str(resume_path)
        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
        checkpoint = torch.load(resume_path)
        self.start_epoch = checkpoint['epoch'] + 1
        self.mnt_best = checkpoint['monitor_best']

        # load architecture params from checkpoint.
        if checkpoint['config']['arch'] != self.config['arch']:
            self.logger.warning("Warning: Architecture configuration given in config file is different from that of "
                                "checkpoint. This may yield an exception while state_dict is being loaded.")
        self.model.load_state_dict(checkpoint['state_dict'])

        # load optimizer state from checkpoint only when optimizer type is not changed.
        if checkpoint['config']['optimizer']['type'] != self.config['optimizer']['type']:
            self.logger.warning("Warning: Optimizer type given in config file is different from that of checkpoint. "
                                "Optimizer parameters not being resumed.")
        else:
            self.optimizer.load_state_dict(checkpoint['optimizer'])

        self.logger.info("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch))
model/config.json
DELETED
@@ -1,50 +0,0 @@
{
    "name": "Mnist_LeNet",
    "n_gpu": 1,

    "arch": {
        "type": "MnistModel",
        "args": {}
    },
    "data_loader": {
        "type": "MnistDataLoader",
        "args":{
            "data_dir": "data/",
            "batch_size": 128,
            "shuffle": true,
            "validation_split": 0.1,
            "num_workers": 2
        }
    },
    "optimizer": {
        "type": "Adam",
        "args":{
            "lr": 0.001,
            "weight_decay": 0,
            "amsgrad": true
        }
    },
    "loss": "nll_loss",
    "metrics": [
        "accuracy", "top_k_acc"
    ],
    "lr_scheduler": {
        "type": "StepLR",
        "args": {
            "step_size": 50,
            "gamma": 0.1
        }
    },
    "trainer": {
        "epochs": 100,

        "save_dir": "saved/",
        "save_period": 1,
        "verbosity": 2,

        "monitor": "min val_loss",
        "early_stop": 10,

        "tensorboard": true
    }
}
model/data_loader/data_loaders.py
DELETED
@@ -1,16 +0,0 @@
from torchvision import datasets, transforms
from base import BaseDataLoader


class MnistDataLoader(BaseDataLoader):
    """
    MNIST data loading demo using BaseDataLoader
    """
    def __init__(self, data_dir, batch_size, shuffle=True, validation_split=0.0, num_workers=1, training=True):
        trsfm = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
        self.data_dir = data_dir
        self.dataset = datasets.MNIST(self.data_dir, train=training, download=True, transform=trsfm)
        super().__init__(self.dataset, batch_size, shuffle, validation_split, num_workers)
model/docker-compose.yml
DELETED
@@ -1,12 +0,0 @@
version: "3" # optional since v1.27.0
services:
  model:
    build: .
    ports:
      - 6006:6006
    volumes:
      - ./saved:/usr/src/app/saved:z
      # - ./out/:/usr/src/app/out:z
      # - ./in/:/usr/src/app/in:z
    #env_file:
    #  - .env
model/logger/__init__.py
DELETED
@@ -1,2 +0,0 @@
from .logger import *
from .visualization import *
model/logger/logger.py
DELETED
@@ -1,22 +0,0 @@
import logging
import logging.config
from pathlib import Path
from utils import read_json


def setup_logging(save_dir, log_config='logger/logger_config.json', default_level=logging.INFO):
    """
    Setup logging configuration
    """
    log_config = Path(log_config)
    if log_config.is_file():
        config = read_json(log_config)
        # modify logging paths based on run config
        for _, handler in config['handlers'].items():
            if 'filename' in handler:
                handler['filename'] = str(save_dir / handler['filename'])

        logging.config.dictConfig(config)
    else:
        print("Warning: logging configuration file is not found in {}.".format(log_config))
        logging.basicConfig(level=default_level)
model/logger/logger_config.json
DELETED
@@ -1,32 +0,0 @@

{
    "version": 1,
    "disable_existing_loggers": false,
    "formatters": {
        "simple": {"format": "%(message)s"},
        "datetime": {"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"}
    },
    "handlers": {
        "console": {
            "class": "logging.StreamHandler",
            "level": "DEBUG",
            "formatter": "simple",
            "stream": "ext://sys.stdout"
        },
        "info_file_handler": {
            "class": "logging.handlers.RotatingFileHandler",
            "level": "INFO",
            "formatter": "datetime",
            "filename": "info.log",
            "maxBytes": 10485760,
            "backupCount": 20, "encoding": "utf8"
        }
    },
    "root": {
        "level": "INFO",
        "handlers": [
            "console",
            "info_file_handler"
        ]
    }
}
model/logger/visualization.py
DELETED
@@ -1,73 +0,0 @@
import importlib
from datetime import datetime


class TensorboardWriter():
    def __init__(self, log_dir, logger, enabled):
        self.writer = None
        self.selected_module = ""

        if enabled:
            log_dir = str(log_dir)

            # Retrieve vizualization writer.
            succeeded = False
            for module in ["torch.utils.tensorboard", "tensorboardX"]:
                try:
                    self.writer = importlib.import_module(module).SummaryWriter(log_dir)
                    succeeded = True
                    break
                except ImportError:
                    succeeded = False
                self.selected_module = module

            if not succeeded:
                message = "Warning: visualization (Tensorboard) is configured to use, but currently not installed on " \
                    "this machine. Please install TensorboardX with 'pip install tensorboardx', upgrade PyTorch to " \
                    "version >= 1.1 to use 'torch.utils.tensorboard' or turn off the option in the 'config.json' file."
                logger.warning(message)

        self.step = 0
        self.mode = ''

        self.tb_writer_ftns = {
            'add_scalar', 'add_scalars', 'add_image', 'add_images', 'add_audio',
            'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding'
        }
        self.tag_mode_exceptions = {'add_histogram', 'add_embedding'}
        self.timer = datetime.now()

    def set_step(self, step, mode='train'):
        self.mode = mode
        self.step = step
        if step == 0:
            self.timer = datetime.now()
        else:
            duration = datetime.now() - self.timer
            self.add_scalar('steps_per_sec', 1 / duration.total_seconds())
            self.timer = datetime.now()

    def __getattr__(self, name):
        """
        If visualization is configured to use:
            return add_data() methods of tensorboard with additional information (step, tag) added.
        Otherwise:
            return a blank function handle that does nothing
        """
        if name in self.tb_writer_ftns:
            add_data = getattr(self.writer, name, None)

            def wrapper(tag, data, *args, **kwargs):
                if add_data is not None:
                    # add mode(train/valid) tag
                    if name not in self.tag_mode_exceptions:
                        tag = '{}/{}'.format(tag, self.mode)
                    add_data(tag, data, self.step, *args, **kwargs)
            return wrapper
        else:
            # default action for returning methods defined in this class, set_step() for instance.
            try:
                attr = object.__getattr__(name)
            except AttributeError:
                raise AttributeError("type object '{}' has no attribute '{}'".format(self.selected_module, name))
            return attr
model/model/loss.py
DELETED
@@ -1,5 +0,0 @@
import torch.nn.functional as F


def nll_loss(output, target):
    return F.nll_loss(output, target)
model/model/metric.py
DELETED
@@ -1,20 +0,0 @@
import torch


def accuracy(output, target):
    with torch.no_grad():
        pred = torch.argmax(output, dim=1)
        assert pred.shape[0] == len(target)
        correct = 0
        correct += torch.sum(pred == target).item()
    return correct / len(target)


def top_k_acc(output, target, k=3):
    with torch.no_grad():
        pred = torch.topk(output, k, dim=1)[1]
        assert pred.shape[0] == len(target)
        correct = 0
        for i in range(k):
            correct += torch.sum(pred[:, i] == target).item()
    return correct / len(target)
model/model/model.py
DELETED
@@ -1,22 +0,0 @@
import torch.nn as nn
import torch.nn.functional as F
from base import BaseModel


class MnistModel(BaseModel):
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, num_classes)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
model/new_project.py
DELETED
@@ -1,18 +0,0 @@
import sys
from pathlib import Path
from shutil import copytree, ignore_patterns


# This script initializes new pytorch project with the template files.
# Run `python3 new_project.py ../MyNewProject` then new project named
# MyNewProject will be made
current_dir = Path()
assert (current_dir / 'new_project.py').is_file(), 'Script should be executed in the pytorch-template directory'
assert len(sys.argv) == 2, 'Specify a name for the new project. Example: python3 new_project.py MyNewProject'

project_name = Path(sys.argv[1])
target_dir = current_dir / project_name

ignore = [".git", "data", "saved", "new_project.py", "LICENSE", ".flake8", "README.md", "__pycache__"]
copytree(current_dir, target_dir, ignore=ignore_patterns(*ignore))
print('New project initialized at', target_dir.absolute().resolve())
model/parse_config.py
DELETED
@@ -1,157 +0,0 @@
import os
import logging
from pathlib import Path
from functools import reduce, partial
from operator import getitem
from datetime import datetime
from logger import setup_logging
from utils import read_json, write_json


class ConfigParser:
    def __init__(self, config, resume=None, modification=None, run_id=None):
        """
        class to parse configuration json file. Handles hyperparameters for training, initializations of modules, checkpoint saving
        and logging module.
        :param config: Dict containing configurations, hyperparameters for training. contents of `config.json` file for example.
        :param resume: String, path to the checkpoint being loaded.
        :param modification: Dict keychain:value, specifying position values to be replaced from config dict.
        :param run_id: Unique Identifier for training processes. Used to save checkpoints and training log. Timestamp is being used as default
        """
        # load config file and apply modification
        self._config = _update_config(config, modification)
        self.resume = resume

        # set save_dir where trained model and log will be saved.
        save_dir = Path(self.config['trainer']['save_dir'])

        exper_name = self.config['name']
        if run_id is None: # use timestamp as default run-id
            run_id = datetime.now().strftime(r'%m%d_%H%M%S')
        self._save_dir = save_dir / 'models' / exper_name / run_id
        self._log_dir = save_dir / 'log' / exper_name / run_id

        # make directory for saving checkpoints and log.
        exist_ok = run_id == ''
        self.save_dir.mkdir(parents=True, exist_ok=exist_ok)
        self.log_dir.mkdir(parents=True, exist_ok=exist_ok)

        # save updated config file to the checkpoint dir
        write_json(self.config, self.save_dir / 'config.json')

        # configure logging module
        setup_logging(self.log_dir)
        self.log_levels = {
            0: logging.WARNING,
            1: logging.INFO,
            2: logging.DEBUG
        }

    @classmethod
    def from_args(cls, args, options=''):
        """
        Initialize this class from some cli arguments. Used in train, test.
        """
        for opt in options:
            args.add_argument(*opt.flags, default=None, type=opt.type)
        if not isinstance(args, tuple):
            args = args.parse_args()

        if args.device is not None:
            os.environ["CUDA_VISIBLE_DEVICES"] = args.device
        if args.resume is not None:
            resume = Path(args.resume)
            cfg_fname = resume.parent / 'config.json'
        else:
            msg_no_cfg = "Configuration file need to be specified. Add '-c config.json', for example."
            assert args.config is not None, msg_no_cfg
            resume = None
            cfg_fname = Path(args.config)

        config = read_json(cfg_fname)
        if args.config and resume:
            # update new config for fine-tuning
            config.update(read_json(args.config))

        # parse custom cli options into dictionary
        modification = {opt.target : getattr(args, _get_opt_name(opt.flags)) for opt in options}
        return cls(config, resume, modification)

    def init_obj(self, name, module, *args, **kwargs):
        """
        Finds a function handle with the name given as 'type' in config, and returns the
        instance initialized with corresponding arguments given.

        `object = config.init_obj('name', module, a, b=1)`
        is equivalent to
        `object = module.name(a, b=1)`
        """
        module_name = self[name]['type']
        module_args = dict(self[name]['args'])
        assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed'
        module_args.update(kwargs)
        return getattr(module, module_name)(*args, **module_args)

    def init_ftn(self, name, module, *args, **kwargs):
        """
        Finds a function handle with the name given as 'type' in config, and returns the
        function with given arguments fixed with functools.partial.

        `function = config.init_ftn('name', module, a, b=1)`
        is equivalent to
        `function = lambda *args, **kwargs: module.name(a, *args, b=1, **kwargs)`.
        """
        module_name = self[name]['type']
        module_args = dict(self[name]['args'])
        assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed'
        module_args.update(kwargs)
        return partial(getattr(module, module_name), *args, **module_args)

    def __getitem__(self, name):
        """Access items like ordinary dict."""
        return self.config[name]

    def get_logger(self, name, verbosity=2):
        msg_verbosity = 'verbosity option {} is invalid. Valid options are {}.'.format(verbosity, self.log_levels.keys())
        assert verbosity in self.log_levels, msg_verbosity
        logger = logging.getLogger(name)
        logger.setLevel(self.log_levels[verbosity])
        return logger

    # setting read-only attributes
    @property
    def config(self):
        return self._config

    @property
    def save_dir(self):
        return self._save_dir

    @property
    def log_dir(self):
        return self._log_dir

# helper functions to update config dict with custom cli options
def _update_config(config, modification):
    if modification is None:
        return config

    for k, v in modification.items():
        if v is not None:
            _set_by_path(config, k, v)
    return config

def _get_opt_name(flags):
    for flg in flags:
        if flg.startswith('--'):
            return flg.replace('--', '')
    return flags[0].replace('--', '')

def _set_by_path(tree, keys, value):
    """Set a value in a nested object in tree by sequence of keys."""
    keys = keys.split(';')
    _get_by_path(tree, keys[:-1])[keys[-1]] = value

def _get_by_path(tree, keys):
    """Access a nested object in tree by sequence of keys."""
    return reduce(getitem, keys, tree)
model/requirements.txt
DELETED
@@ -1,6 +0,0 @@
torch>=1.1
torchvision
numpy
tqdm
tensorboardx
pandas
model/test.py
DELETED
@@ -1,81 +0,0 @@
import argparse
import torch
from tqdm import tqdm
import data_loader.data_loaders as module_data
import model.loss as module_loss
import model.metric as module_metric
import model.model as module_arch
from parse_config import ConfigParser


def main(config):
    logger = config.get_logger('test')

    # setup data_loader instances
    data_loader = getattr(module_data, config['data_loader']['type'])(
        config['data_loader']['args']['data_dir'],
        batch_size=512,
        shuffle=False,
        validation_split=0.0,
        training=False,
        num_workers=2
    )

    # build model architecture
    model = config.init_obj('arch', module_arch)
    logger.info(model)

    # get function handles of loss and metrics
    loss_fn = getattr(module_loss, config['loss'])
    metric_fns = [getattr(module_metric, met) for met in config['metrics']]

    logger.info('Loading checkpoint: {} ...'.format(config.resume))
    checkpoint = torch.load(config.resume)
    state_dict = checkpoint['state_dict']
    if config['n_gpu'] > 1:
        model = torch.nn.DataParallel(model)
    model.load_state_dict(state_dict)

    # prepare model for testing
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()

    total_loss = 0.0
    total_metrics = torch.zeros(len(metric_fns))

    with torch.no_grad():
        for i, (data, target) in enumerate(tqdm(data_loader)):
            data, target = data.to(device), target.to(device)
            output = model(data)

            #
            # save sample images, or do something with output here
            #

            # computing loss, metrics on test set
            loss = loss_fn(output, target)
            batch_size = data.shape[0]
            total_loss += loss.item() * batch_size
            for i, metric in enumerate(metric_fns):
                total_metrics[i] += metric(output, target) * batch_size

    n_samples = len(data_loader.sampler)
    log = {'loss': total_loss / n_samples}
    log.update({
        met.__name__: total_metrics[i].item() / n_samples for i, met in enumerate(metric_fns)
    })
    logger.info(log)


if __name__ == '__main__':
    args = argparse.ArgumentParser(description='PyTorch Template')
    args.add_argument('-c', '--config', default=None, type=str,
                      help='config file path (default: None)')
    args.add_argument('-r', '--resume', default=None, type=str,
                      help='path to latest checkpoint (default: None)')
    args.add_argument('-d', '--device', default=None, type=str,
                      help='indices of GPUs to enable (default: all)')

    config = ConfigParser.from_args(args)
    main(config)
model/train.py
DELETED
@@ -1,73 +0,0 @@
import argparse
import collections
import torch
import numpy as np
import data_loader.data_loaders as module_data
import model.loss as module_loss
import model.metric as module_metric
import model.model as module_arch
from parse_config import ConfigParser
from trainer import Trainer
from utils import prepare_device


# fix random seeds for reproducibility
SEED = 123
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(SEED)

def main(config):
    logger = config.get_logger('train')

    # setup data_loader instances
    data_loader = config.init_obj('data_loader', module_data)
    valid_data_loader = data_loader.split_validation()

    # build model architecture, then print to console
    model = config.init_obj('arch', module_arch)
    logger.info(model)

    # prepare for (multi-device) GPU training
    device, device_ids = prepare_device(config['n_gpu'])
    model = model.to(device)
    if len(device_ids) > 1:
        model = torch.nn.DataParallel(model, device_ids=device_ids)

    # get function handles of loss and metrics
    criterion = getattr(module_loss, config['loss'])
    metrics = [getattr(module_metric, met) for met in config['metrics']]

    # build optimizer, learning rate scheduler. delete every lines containing lr_scheduler for disabling scheduler
    trainable_params = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = config.init_obj('optimizer', torch.optim, trainable_params)
    lr_scheduler = config.init_obj('lr_scheduler', torch.optim.lr_scheduler, optimizer)

    trainer = Trainer(model, criterion, metrics, optimizer,
                      config=config,
                      device=device,
                      data_loader=data_loader,
                      valid_data_loader=valid_data_loader,
                      lr_scheduler=lr_scheduler)

    trainer.train()


if __name__ == '__main__':
    args = argparse.ArgumentParser(description='PyTorch Template')
    args.add_argument('-c', '--config', default=None, type=str,
                      help='config file path (default: None)')
    args.add_argument('-r', '--resume', default=None, type=str,
                      help='path to latest checkpoint (default: None)')
    args.add_argument('-d', '--device', default=None, type=str,
                      help='indices of GPUs to enable (default: all)')

    # custom cli options to modify configuration from default values given in json file.
    CustomArgs = collections.namedtuple('CustomArgs', 'flags type target')
    options = [
        CustomArgs(['--lr', '--learning_rate'], type=float, target='optimizer;args;lr'),
        CustomArgs(['--bs', '--batch_size'], type=int, target='data_loader;args;batch_size')
    ]
    config = ConfigParser.from_args(args, options)
    main(config)
model/trainer/__init__.py
DELETED
@@ -1 +0,0 @@
-from .trainer import *

model/trainer/trainer.py
DELETED
@@ -1,110 +0,0 @@
-import numpy as np
-import torch
-from torchvision.utils import make_grid
-from base import BaseTrainer
-from utils import inf_loop, MetricTracker
-
-
-class Trainer(BaseTrainer):
-    """
-    Trainer class
-    """
-    def __init__(self, model, criterion, metric_ftns, optimizer, config, device,
-                 data_loader, valid_data_loader=None, lr_scheduler=None, len_epoch=None):
-        super().__init__(model, criterion, metric_ftns, optimizer, config)
-        self.config = config
-        self.device = device
-        self.data_loader = data_loader
-        if len_epoch is None:
-            # epoch-based training
-            self.len_epoch = len(self.data_loader)
-        else:
-            # iteration-based training
-            self.data_loader = inf_loop(data_loader)
-            self.len_epoch = len_epoch
-        self.valid_data_loader = valid_data_loader
-        self.do_validation = self.valid_data_loader is not None
-        self.lr_scheduler = lr_scheduler
-        self.log_step = int(np.sqrt(data_loader.batch_size))
-
-        self.train_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)
-        self.valid_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)
-
-    def _train_epoch(self, epoch):
-        """
-        Training logic for an epoch
-
-        :param epoch: Integer, current training epoch.
-        :return: A log that contains average loss and metric in this epoch.
-        """
-        self.model.train()
-        self.train_metrics.reset()
-        for batch_idx, (data, target) in enumerate(self.data_loader):
-            data, target = data.to(self.device), target.to(self.device)
-
-            self.optimizer.zero_grad()
-            output = self.model(data)
-            loss = self.criterion(output, target)
-            loss.backward()
-            self.optimizer.step()
-
-            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
-            self.train_metrics.update('loss', loss.item())
-            for met in self.metric_ftns:
-                self.train_metrics.update(met.__name__, met(output, target))
-
-            if batch_idx % self.log_step == 0:
-                self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
-                    epoch,
-                    self._progress(batch_idx),
-                    loss.item()))
-                self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
-
-            if batch_idx == self.len_epoch:
-                break
-        log = self.train_metrics.result()
-
-        if self.do_validation:
-            val_log = self._valid_epoch(epoch)
-            log.update(**{'val_'+k : v for k, v in val_log.items()})
-
-        if self.lr_scheduler is not None:
-            self.lr_scheduler.step()
-        return log
-
-    def _valid_epoch(self, epoch):
-        """
-        Validate after training an epoch
-
-        :param epoch: Integer, current training epoch.
-        :return: A log that contains information about validation
-        """
-        self.model.eval()
-        self.valid_metrics.reset()
-        with torch.no_grad():
-            for batch_idx, (data, target) in enumerate(self.valid_data_loader):
-                data, target = data.to(self.device), target.to(self.device)
-
-                output = self.model(data)
-                loss = self.criterion(output, target)
-
-                self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid')
-                self.valid_metrics.update('loss', loss.item())
-                for met in self.metric_ftns:
-                    self.valid_metrics.update(met.__name__, met(output, target))
-                self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
-
-        # add histogram of model parameters to the tensorboard
-        for name, p in self.model.named_parameters():
-            self.writer.add_histogram(name, p, bins='auto')
-        return self.valid_metrics.result()
-
-    def _progress(self, batch_idx):
-        base = '[{}/{} ({:.0f}%)]'
-        if hasattr(self.data_loader, 'n_samples'):
-            current = batch_idx * self.data_loader.batch_size
-            total = self.data_loader.n_samples
-        else:
-            current = batch_idx
-            total = self.len_epoch
-        return base.format(current, total, 100.0 * current / total)

model/utils/__init__.py
DELETED
@@ -1 +0,0 @@
-from .util import *

model/utils/util.py
DELETED
@@ -1,67 +0,0 @@
-import json
-import torch
-import pandas as pd
-from pathlib import Path
-from itertools import repeat
-from collections import OrderedDict
-
-
-def ensure_dir(dirname):
-    dirname = Path(dirname)
-    if not dirname.is_dir():
-        dirname.mkdir(parents=True, exist_ok=False)
-
-def read_json(fname):
-    fname = Path(fname)
-    with fname.open('rt') as handle:
-        return json.load(handle, object_hook=OrderedDict)
-
-def write_json(content, fname):
-    fname = Path(fname)
-    with fname.open('wt') as handle:
-        json.dump(content, handle, indent=4, sort_keys=False)
-
-def inf_loop(data_loader):
-    ''' wrapper function for endless data loader. '''
-    for loader in repeat(data_loader):
-        yield from loader
-
-def prepare_device(n_gpu_use):
-    """
-    setup GPU device if available. get gpu device indices which are used for DataParallel
-    """
-    n_gpu = torch.cuda.device_count()
-    if n_gpu_use > 0 and n_gpu == 0:
-        print("Warning: There\'s no GPU available on this machine,"
-              "training will be performed on CPU.")
-        n_gpu_use = 0
-    if n_gpu_use > n_gpu:
-        print(f"Warning: The number of GPU\'s configured to use is {n_gpu_use}, but only {n_gpu} are "
-              "available on this machine.")
-        n_gpu_use = n_gpu
-    device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
-    list_ids = list(range(n_gpu_use))
-    return device, list_ids
-
-class MetricTracker:
-    def __init__(self, *keys, writer=None):
-        self.writer = writer
-        self._data = pd.DataFrame(index=keys, columns=['total', 'counts', 'average'])
-        self.reset()
-
-    def reset(self):
-        for col in self._data.columns:
-            self._data[col].values[:] = 0
-
-    def update(self, key, value, n=1):
-        if self.writer is not None:
-            self.writer.add_scalar(key, value)
-        self._data.total[key] += value * n
-        self._data.counts[key] += n
-        self._data.average[key] = self._data.total[key] / self._data.counts[key]
-
-    def avg(self, key):
-        return self._data.average[key]
-
-    def result(self):
-        return dict(self._data.average)

models/.env.example
ADDED
@@ -0,0 +1,7 @@
+# this is example of the file that can be used for storing private and user specific environment variables, like keys or system paths
+# create a file named .env (by default .env will be excluded from version control)
+# the variables declared in .env are loaded in train.py automatically
+# hydra allows you to reference variables in .yaml configs with special syntax: ${oc.env:MY_VAR}
+
+MY_VAR="/home/user/my/system/path"
+MY_KEY="asdgjhawi8y23ihsghsueity23ihwd"

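A small, hypothetical illustration of the `${oc.env:MY_VAR}` syntax mentioned in the comments above; the `data_dir` key and file are invented for this sketch and are not part of this commit:

```yaml
# any Hydra .yaml config under configs/ could read a value from .env like this
data_dir: ${oc.env:MY_VAR}   # resolves to "/home/user/my/system/path" at runtime
```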
models/.gitignore
ADDED
@@ -0,0 +1,148 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+### VisualStudioCode
+.vscode/*
+!.vscode/settings.json
+!.vscode/tasks.json
+!.vscode/launch.json
+!.vscode/extensions.json
+*.code-workspace
+**/.vscode
+
+# JetBrains
+.idea/
+
+# Lightning-Hydra-Template
+configs/local/default.yaml
+data/
+logs/
+wandb/
+.env
+.autoenv

models/.pre-commit-config.yaml
ADDED
@@ -0,0 +1,44 @@
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v4.1.0
+    hooks:
+      # list of supported hooks: https://pre-commit.com/hooks.html
+      - id: trailing-whitespace
+      - id: end-of-file-fixer
+      - id: check-yaml
+      - id: check-added-large-files
+      - id: debug-statements
+      - id: detect-private-key
+
+  # python code formatting
+  - repo: https://github.com/psf/black
+    rev: 22.1.0
+    hooks:
+      - id: black
+        args: [--line-length, "99"]
+
+  # python import sorting
+  - repo: https://github.com/PyCQA/isort
+    rev: 5.10.1
+    hooks:
+      - id: isort
+        args: ["--profile", "black", "--filter-files"]
+
+  # yaml formatting
+  - repo: https://github.com/pre-commit/mirrors-prettier
+    rev: v2.5.1
+    hooks:
+      - id: prettier
+        types: [yaml]
+
+  # python code analysis
+  - repo: https://github.com/PyCQA/flake8
+    rev: 4.0.1
+    hooks:
+      - id: flake8
+
+  # jupyter notebook cell output clearing
+  - repo: https://github.com/kynan/nbstripout
+    rev: 0.5.0
+    hooks:
+      - id: nbstripout

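These hooks only run on commit after pre-commit has been installed into the local checkout; the standard pre-commit CLI workflow (not specific to this repository) looks like this:

```bash
pip install pre-commit     # install the tool
pre-commit install         # register the git hook in .git/hooks
pre-commit run -a          # optionally run all hooks against all files once
```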
models/README.md
ADDED
@@ -0,0 +1,1445 @@
+<div align="center">
+
+# Lightning-Hydra-Template
+
+<a href="https://www.python.org/"><img alt="Python" src="https://img.shields.io/badge/-Python 3.7+-blue?style=for-the-badge&logo=python&logoColor=white"></a>
+<a href="https://pytorch.org/get-started/locally/"><img alt="PyTorch" src="https://img.shields.io/badge/-PyTorch 1.8+-ee4c2c?style=for-the-badge&logo=pytorch&logoColor=white"></a>
+<a href="https://pytorchlightning.ai/"><img alt="Lightning" src="https://img.shields.io/badge/-Lightning 1.5+-792ee5?style=for-the-badge&logo=pytorchlightning&logoColor=white"></a>
+<a href="https://hydra.cc/"><img alt="Config: hydra" src="https://img.shields.io/badge/config-hydra 1.1-89b8cd?style=for-the-badge&labelColor=gray"></a>
+<a href="https://black.readthedocs.io/en/stable/"><img alt="Code style: black" src="https://img.shields.io/badge/code%20style-black-black.svg?style=for-the-badge&labelColor=gray"></a>
+
+A clean and scalable template to kickstart your deep learning project 🚀⚡🔥<br>
+Click on [<kbd>Use this template</kbd>](https://github.com/ashleve/lightning-hydra-template/generate) to initialize new repository.
+
+_Suggestions are always welcome!_
+
+</div>
+
+<br><br>
+
+## 📌 Introduction
+
+This template tries to be as general as possible. It integrates many different MLOps tools.
+
+> Effective usage of this template requires learning of a couple of technologies: [PyTorch](https://pytorch.org), [PyTorch Lightning](https://www.pytorchlightning.ai) and [Hydra](https://hydra.cc). Knowledge of some experiment logging framework like [Weights&Biases](https://wandb.com), [Neptune](https://neptune.ai) or [MLFlow](https://mlflow.org) is also recommended.
+
+**Why you should use it:** it allows you to rapidly iterate over new models/datasets and scale your projects from small single experiments to hyperparameter searches on computing clusters, without writing any boilerplate code. To my knowledge, it's one of the most convenient all-in-one technology stacks for Deep Learning research. It's a good starting point for reproducing papers, kaggle competitions or small-team research projects. It's also a collection of best practices for efficient workflow and reproducibility.
+
+**Why you shouldn't use it:** this template is not meant to be a production environment and should be used more as a fast experimentation tool. Apart from that, Lightning and Hydra are still evolving and integrate many libraries, which means sometimes things break - for the list of currently known bugs, visit [this page](https://github.com/ashleve/lightning-hydra-template/labels/bug). Also, even though Lightning is very flexible, it's not well suited for every possible deep learning task. See [#Limitations](#limitations) for more.
+
+### Why PyTorch Lightning?
+
+[PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning) is a lightweight PyTorch wrapper for high-performance AI research.
+It makes your code neatly organized and provides lots of useful features, like the ability to run models on CPU, GPU, multi-GPU cluster and TPU.
+
+### Why Hydra?
+
+[Hydra](https://github.com/facebookresearch/hydra) is an open-source Python framework that simplifies the development of research and other complex applications. The key feature is the ability to dynamically create a hierarchical configuration by composition and override it through config files and the command line. It allows you to conveniently manage experiments and provides many useful plugins, like [Optuna Sweeper](https://hydra.cc/docs/next/plugins/optuna_sweeper) for hyperparameter search, or [Ray Launcher](https://hydra.cc/docs/next/plugins/ray_launcher) for running jobs on a cluster.
+
+<br>
+
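As a rough illustration of the composition and command-line override described above, a minimal generic Hydra app looks roughly like the sketch below; this is not code from this repository, and the config names are only examples:

```python
# my_app.py -- minimal Hydra sketch; run e.g. `python my_app.py trainer.max_epochs=20`
import hydra
from omegaconf import DictConfig, OmegaConf


@hydra.main(config_path="configs", config_name="train")
def main(cfg: DictConfig) -> None:
    # cfg is composed from configs/train.yaml, its defaults list, and any CLI overrides
    print(OmegaConf.to_yaml(cfg))


if __name__ == "__main__":
    main()
```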
+## Main Ideas Of This Template
+
+- **Predefined Structure**: clean and scalable so that work can easily be extended and replicated | [#Project Structure](#project-structure)
+- **Rapid Experimentation**: thanks to automating pipeline with config files and hydra command line superpowers | [#Your Superpowers](#your-superpowers)
+- **Reproducibility**: obtaining similar results is supported in multiple ways | [#Reproducibility](#reproducibility)
+- **Little Boilerplate**: so pipeline can be easily modified | [#How It Works](#how-it-works)
+- **Main Configuration**: main config file specifies default training configuration | [#Main Project Configuration](#main-project-configuration)
+- **Experiment Configurations**: can be composed out of smaller configs and override chosen hyperparameters | [#Experiment Configuration](#experiment-configuration)
+- **Workflow**: comes down to 4 simple steps | [#Workflow](#workflow)
+- **Experiment Tracking**: many logging frameworks can be easily integrated, like Tensorboard, MLFlow or W&B | [#Experiment Tracking](#experiment-tracking)
+- **Logs**: all logs (checkpoints, data from loggers, hparams, etc.) are stored in a convenient folder structure imposed by Hydra | [#Logs](#logs)
+- **Hyperparameter Search**: made easier with Hydra built-in plugins like [Optuna Sweeper](https://hydra.cc/docs/next/plugins/optuna_sweeper) | [#Hyperparameter Search](#hyperparameter-search)
+- **Tests**: unit tests and shell/command based tests for speeding up the development | [#Tests](#tests)
+- **Best Practices**: a couple of recommended tools, practices and standards for efficient workflow and reproducibility | [#Best Practices](#best-practices)
+
+<br>
+
+## Project Structure
+
+The directory structure of new project looks like this:
+
+```
+├── configs                  <- Hydra configuration files
+│   ├── callbacks            <- Callbacks configs
+│   ├── datamodule           <- Datamodule configs
+│   ├── debug                <- Debugging configs
+│   ├── experiment           <- Experiment configs
+│   ├── hparams_search       <- Hyperparameter search configs
+│   ├── local                <- Local configs
+│   ├── log_dir              <- Logging directory configs
+│   ├── logger               <- Logger configs
+│   ├── model                <- Model configs
+│   ├── trainer              <- Trainer configs
+│   │
+│   ├── test.yaml            <- Main config for testing
+│   └── train.yaml           <- Main config for training
+│
+├── data                     <- Project data
+│
+├── logs                     <- Logs generated by Hydra and PyTorch Lightning loggers
+│
+├── notebooks                <- Jupyter notebooks. Naming convention is a number (for ordering),
+│                               the creator's initials, and a short `-` delimited description,
+│                               e.g. `1.0-jqp-initial-data-exploration.ipynb`.
+│
+├── scripts                  <- Shell scripts
+│
+├── src                      <- Source code
+│   ├── datamodules          <- Lightning datamodules
+│   ├── models               <- Lightning models
+│   ├── utils                <- Utility scripts
+│   ├── vendor               <- Third party code that cannot be installed using PIP/Conda
+│   │
+│   ├── testing_pipeline.py
+│   └── training_pipeline.py
+│
+├── tests                    <- Tests of any kind
+│   ├── helpers              <- A couple of testing utilities
+│   ├── shell                <- Shell/command based tests
+│   └── unit                 <- Unit tests
+│
+├── test.py                  <- Run testing
+├── train.py                 <- Run training
+│
+├── .env.example             <- Template of the file for storing private environment variables
+├── .gitignore               <- List of files/folders ignored by git
+├── .pre-commit-config.yaml  <- Configuration of pre-commit hooks for code formatting
+├── requirements.txt         <- File for installing python dependencies
+├── setup.cfg                <- Configuration of linters and pytest
+└── README.md
+```
+
+<br>
+
+## 🚀 Quickstart
+
+```bash
+# clone project
+git clone https://github.com/ashleve/lightning-hydra-template
+cd lightning-hydra-template
+
+# [OPTIONAL] create conda environment
+conda create -n myenv python=3.8
+conda activate myenv
+
+# install pytorch according to instructions
+# https://pytorch.org/get-started/
+
+# install requirements
+pip install -r requirements.txt
+```
+
+Template contains example with MNIST classification.<br>
+When running `python train.py` you should see something like this:
+
+<div align="center">
+
+![](https://github.com/ashleve/lightning-hydra-template/blob/resources/terminal.png)
+
+</div>
+
+### ⚡ Your Superpowers
+
+<details>
+<summary><b>Override any config parameter from command line</b></summary>
+
+> Hydra allows you to easily overwrite any parameter defined in your config.
+
+```bash
+python train.py trainer.max_epochs=20 model.lr=1e-4
+```
+
+> You can also add new parameters with `+` sign.
+
+```bash
+python train.py +model.new_param="uwu"
+```
+
+</details>
+
+<details>
+<summary><b>Train on CPU, GPU, multi-GPU and TPU</b></summary>
+
+> PyTorch Lightning makes it easy to train your models on different hardware.
+
+```bash
+# train on CPU
+python train.py trainer.gpus=0
+
+# train on 1 GPU
+python train.py trainer.gpus=1
+
+# train on TPU
+python train.py +trainer.tpu_cores=8
+
+# train with DDP (Distributed Data Parallel) (4 GPUs)
+python train.py trainer.gpus=4 +trainer.strategy=ddp
+
+# train with DDP (Distributed Data Parallel) (8 GPUs, 2 nodes)
+python train.py trainer.gpus=4 +trainer.num_nodes=2 +trainer.strategy=ddp
+```
+
+</details>
+
+<details>
+<summary><b>Train with mixed precision</b></summary>
+
+```bash
+# train with pytorch native automatic mixed precision (AMP)
+python train.py trainer.gpus=1 +trainer.precision=16
+```
+
+</details>
+
+<!-- deepspeed support still in beta
+<details>
+<summary><b>Optimize large scale models on multiple GPUs with Deepspeed</b></summary>
+
+```bash
+python train.py +trainer.
+```
+
+</details>
+-->
+
+<details>
+<summary><b>Train model with any logger available in PyTorch Lightning, like Weights&Biases or Tensorboard</b></summary>
+
+> PyTorch Lightning provides convenient integrations with most popular logging frameworks, like Tensorboard, Neptune or simple csv files. Read more [here](#experiment-tracking). Using wandb requires you to [setup account](https://www.wandb.com/) first. After that just complete the config as below.<br> > **Click [here](https://wandb.ai/hobglob/template-dashboard/) to see example wandb dashboard generated with this template.**
+
+```bash
+# set project and entity names in `configs/logger/wandb`
+wandb:
+  project: "your_project_name"
+  entity: "your_wandb_team_name"
+```
+
+```bash
+# train model with Weights&Biases (link to wandb dashboard should appear in the terminal)
+python train.py logger=wandb
+```
+
+</details>
+
+<details>
+<summary><b>Train model with chosen experiment config</b></summary>
+
+> Experiment configurations are placed in [configs/experiment/](configs/experiment/).
+
+```bash
+python train.py experiment=example
+```
+
+</details>
+
+<details>
+<summary><b>Attach some callbacks to run</b></summary>
+
+> Callbacks can be used for things such as model checkpointing, early stopping and [many more](https://pytorch-lightning.readthedocs.io/en/latest/extensions/callbacks.html#built-in-callbacks).<br>
+> Callbacks configurations are placed in [configs/callbacks/](configs/callbacks/).
+
+```bash
+python train.py callbacks=default
+```
+
+</details>
+
+<details>
+<summary><b>Use different tricks available in Pytorch Lightning</b></summary>
+
+> PyTorch Lightning provides about [40+ useful trainer flags](https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags).
+
+```yaml
+# gradient clipping may be enabled to avoid exploding gradients
+python train.py +trainer.gradient_clip_val=0.5
+
+# stochastic weight averaging can make your models generalize better
+python train.py +trainer.stochastic_weight_avg=true
+
+# run validation loop 4 times during a training epoch
+python train.py +trainer.val_check_interval=0.25
+
+# accumulate gradients
+python train.py +trainer.accumulate_grad_batches=10
+
+# terminate training after 12 hours
+python train.py +trainer.max_time="00:12:00:00"
+```
+
+</details>
+
+<details>
+<summary><b>Easily debug</b></summary>
+
+> Visit [configs/debug/](configs/debug/) for different debugging configs.
+
+```bash
+# runs 1 epoch in default debugging mode
+# changes logging directory to `logs/debugs/...`
+# sets level of all command line loggers to 'DEBUG'
+# enables extra trainer flags like tracking gradient norm
+# enforces debug-friendly configuration
+python train.py debug=default
+
+# runs test epoch without training
+python train.py debug=test_only
+
+# run 1 train, val and test loop, using only 1 batch
+python train.py +trainer.fast_dev_run=true
+
+# raise exception if there are any numerical anomalies in tensors, like NaN or +/-inf
+python train.py +trainer.detect_anomaly=true
+
+# print execution time profiling after training ends
+python train.py +trainer.profiler="simple"
+
+# try overfitting to 1 batch
+python train.py +trainer.overfit_batches=1 trainer.max_epochs=20
+
+# use only 20% of the data
+python train.py +trainer.limit_train_batches=0.2 \
++trainer.limit_val_batches=0.2 +trainer.limit_test_batches=0.2
+
+# log second gradient norm of the model
+python train.py +trainer.track_grad_norm=2
+```
+
+</details>
+
+<details>
+<summary><b>Resume training from checkpoint</b></summary>
+
+> Checkpoint can be either path or URL.
+
+```yaml
+python train.py trainer.resume_from_checkpoint="/path/to/ckpt/name.ckpt"
+```
+
+> ⚠️ Currently loading ckpt in Lightning doesn't resume logger experiment, but it will be supported in future Lightning release.
+
+</details>
+
+<details>
+<summary><b>Execute evaluation for a given checkpoint</b></summary>
+
+> Checkpoint can be either path or URL.
+
+```yaml
+python test.py ckpt_path="/path/to/ckpt/name.ckpt"
+```
+
+</details>
+
+<details>
+<summary><b>Create a sweep over hyperparameters</b></summary>
+
+```bash
+# this will run 6 experiments one after the other,
+# each with different combination of batch_size and learning rate
+python train.py -m datamodule.batch_size=32,64,128 model.lr=0.001,0.0005
+```
+
+> ⚠️ This sweep is not failure resistant (if one job crashes then the whole sweep crashes).
+
+</details>
+
+<details>
+<summary><b>Create a sweep over hyperparameters with Optuna</b></summary>
+
+> Using [Optuna Sweeper](https://hydra.cc/docs/next/plugins/optuna_sweeper) plugin doesn't require you to code any boilerplate into your pipeline, everything is defined in a [single config file](configs/hparams_search/mnist_optuna.yaml)!
+
+```bash
+# this will run hyperparameter search defined in `configs/hparams_search/mnist_optuna.yaml`
+# over chosen experiment config
+python train.py -m hparams_search=mnist_optuna experiment=example_simple
+```
+
+> ⚠️ Currently this sweep is not failure resistant (if one job crashes then the whole sweep crashes). Might be supported in future Hydra release.
+
+</details>
+
+<details>
+<summary><b>Execute all experiments from folder</b></summary>
+
+> Hydra provides special syntax for controlling behavior of multiruns. Learn more [here](https://hydra.cc/docs/next/tutorials/basic/running_your_app/multi-run). The command below executes all experiments from folder [configs/experiment/](configs/experiment/).
+
+```bash
+python train.py -m 'experiment=glob(*)'
+```
+
+</details>
+
+<details>
+<summary><b>Execute sweep on a remote AWS cluster</b></summary>
+
+> This should be achievable with simple config using [Ray AWS launcher for Hydra](https://hydra.cc/docs/next/plugins/ray_launcher). Example is not yet implemented in this template.
+
+</details>
+
+<!-- <details>
+<summary><b>Execute sweep on a SLURM cluster</b></summary>
+
+> This should be achievable with either [the right lightning trainer flags](https://pytorch-lightning.readthedocs.io/en/latest/clouds/cluster.html?highlight=SLURM#slurm-managed-cluster) or simple config using [Submitit launcher for Hydra](https://hydra.cc/docs/plugins/submitit_launcher). Example is not yet implemented in this template.
+
+</details> -->
+
+<details>
+<summary><b>Use Hydra tab completion</b></summary>
+
+> Hydra allows you to autocomplete config argument overrides in shell as you write them, by pressing `tab` key. Learn more [here](https://hydra.cc/docs/tutorials/basic/running_your_app/tab_completion).
+
+</details>
+
+<details>
+<summary><b>Apply pre-commit hooks</b></summary>
+
+> Apply pre-commit hooks to automatically format your code and configs, perform code analysis and remove output from jupyter notebooks. See [# Best Practices](#best-practices) for more.
+
+```bash
+pre-commit run -a
+```
+
+</details>
+
+<br>
+
+## ❤️ Contributions
+
+Have a question? Found a bug? Missing a specific feature? Have an idea for improving documentation? Feel free to file a new issue, discussion or PR with respective title and description. If you already found a solution to your problem, don't hesitate to share it. Suggestions for new best practices are always welcome!
+
+<br>
+
+## ℹ️ Guide
+
+### How To Get Started
+
+- First, you should probably get familiar with [PyTorch Lightning](https://www.pytorchlightning.ai)
+- Next, go through [Hydra quick start guide](https://hydra.cc/docs/intro/) and [basic Hydra tutorial](https://hydra.cc/docs/tutorials/basic/your_first_app/simple_cli/)
+
+<br>
+
+### How It Works
+
+All PyTorch Lightning modules are dynamically instantiated from module paths specified in config. Example model config:
+
+```yaml
+_target_: src.models.mnist_model.MNISTLitModule
+input_size: 784
+lin1_size: 256
+lin2_size: 256
+lin3_size: 256
+output_size: 10
+lr: 0.001
+```
+
+Using this config we can instantiate the object with the following line:
+
+```python
+model = hydra.utils.instantiate(config.model)
+```
+
+This allows you to easily iterate over new models! Every time you create a new one, just specify its module path and parameters in appropriate config file. <br>
+
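For the `_target_` mechanism above to work, the config keys simply have to match the constructor arguments of the class being pointed at. The real `MNISTLitModule` is not part of this diff, so the following is only a sketch of the expected shape, assuming a plain fully-connected network:

```python
# src/models/mnist_model.py (illustrative sketch, not the template's actual module)
from pytorch_lightning import LightningModule
from torch import nn


class MNISTLitModule(LightningModule):
    # every keyword argument here can be set from the yaml config shown above
    def __init__(self, input_size: int = 784, lin1_size: int = 256, lin2_size: int = 256,
                 lin3_size: int = 256, output_size: int = 10, lr: float = 0.001):
        super().__init__()
        self.save_hyperparameters()  # exposes the arguments as self.hparams
        self.net = nn.Sequential(
            nn.Linear(input_size, lin1_size), nn.ReLU(),
            nn.Linear(lin1_size, lin2_size), nn.ReLU(),
            nn.Linear(lin2_size, lin3_size), nn.ReLU(),
            nn.Linear(lin3_size, output_size),
        )
```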
+Switch between models and datamodules with command line arguments:
+
+```bash
+python train.py model=mnist
+```
+
+The whole pipeline managing the instantiation logic is placed in [src/training_pipeline.py](src/training_pipeline.py).
+
+<br>
+
+### Main Project Configuration
+
+Location: [configs/train.yaml](configs/train.yaml) <br>
+Main project config contains default training configuration.<br>
+It determines how config is composed when simply executing command `python train.py`.<br>
+
+<details>
+<summary><b>Show main project config</b></summary>
+
+```yaml
+# specify here default training configuration
+defaults:
+  - _self_
+  - datamodule: mnist.yaml
+  - model: mnist.yaml
+  - callbacks: default.yaml
+  - logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
+  - trainer: default.yaml
+  - log_dir: default.yaml
+
+  # experiment configs allow for version control of specific configurations
+  # e.g. best hyperparameters for each combination of model and datamodule
+  - experiment: null
+
+  # debugging config (enable through command line, e.g. `python train.py debug=default)
+  - debug: null
+
+  # config for hyperparameter optimization
+  - hparams_search: null
+
+  # optional local config for machine/user specific settings
+  # it's optional since it doesn't need to exist and is excluded from version control
+  - optional local: default.yaml
+
+  # enable color logging
+  - override hydra/hydra_logging: colorlog
+  - override hydra/job_logging: colorlog
+
+# path to original working directory
+# hydra hijacks working directory by changing it to the new log directory
+# https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory
+original_work_dir: ${hydra:runtime.cwd}
+
+# path to folder with data
+data_dir: ${original_work_dir}/data/
+
+# pretty print config at the start of the run using Rich library
+print_config: True
+
+# disable python warnings if they annoy you
+ignore_warnings: True
+
+# set False to skip model training
+train: True
+
+# evaluate on test set, using best model weights achieved during training
+# lightning chooses best weights based on the metric specified in checkpoint callback
+test: True
+
+# seed for random number generators in pytorch, numpy and python.random
+seed: null
+
+# default name for the experiment, determines logging folder path
+# (you can overwrite this name in experiment configs)
+name: "default"
+```
+
+</details>
+
+<br>
+
+### Experiment Configuration
+
+Location: [configs/experiment](configs/experiment)<br>
+Experiment configs allow you to overwrite parameters from main project configuration.<br>
+For example, you can use them to version control best hyperparameters for each combination of model and dataset.
+
+<details>
+<summary><b>Show example experiment config</b></summary>
+
+```yaml
+# to execute this experiment run:
+# python train.py experiment=example
+
+defaults:
+  - override /datamodule: mnist.yaml
+  - override /model: mnist.yaml
+  - override /callbacks: default.yaml
+  - override /logger: null
+  - override /trainer: default.yaml
+
+# all parameters below will be merged with parameters from default configurations set above
+# this allows you to overwrite only specified parameters
+
+# name of the run determines folder name in logs
+name: "simple_dense_net"
+
+seed: 12345
+
+trainer:
+  min_epochs: 10
+  max_epochs: 10
+  gradient_clip_val: 0.5
+
+model:
+  lin1_size: 128
+  lin2_size: 256
+  lin3_size: 64
+  lr: 0.002
+
+datamodule:
+  batch_size: 64
+
+logger:
+  wandb:
+    tags: ["mnist", "${name}"]
+```
+
+</details>
+
+<br>
+
+### Local Configuration
+
+Location: [configs/local](configs/local) <br>
+Some configurations are user/machine/installation specific (e.g. configuration of local cluster, or harddrive paths on a specific machine). For such scenarios, a file `configs/local/default.yaml` can be created which is automatically loaded but not tracked by Git.
+
+<details>
+<summary><b>Show example local Slurm cluster config</b></summary>
+
+```yaml
+# @package _global_
+
+defaults:
+  - override /hydra/launcher@_here_: submitit_slurm
+
+data_dir: /mnt/scratch/data/
+
+hydra:
+  launcher:
+    timeout_min: 1440
+    gpus_per_task: 1
+    gres: gpu:1
+  job:
+    env_set:
+      MY_VAR: /home/user/my/system/path
+      MY_KEY: asdgjhawi8y23ihsghsueity23ihwd
+```
+
+</details>
+
+<br>
+
+### Workflow
+
+1. Write your PyTorch Lightning module (see [models/mnist_module.py](src/models/mnist_module.py) for example)
+2. Write your PyTorch Lightning datamodule (see [datamodules/mnist_datamodule.py](src/datamodules/mnist_datamodule.py) for example; a minimal sketch follows this list)
+3. Write your experiment config, containing paths to your model and datamodule
+4. Run training with chosen experiment config: `python train.py experiment=experiment_name`
+
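As referenced in step 2, a bare-bones datamodule has roughly the shape below. The template's actual MNIST datamodule is not included in this diff, so every name here is illustrative and the data is a stand-in:

```python
# src/datamodules/my_datamodule.py (illustrative sketch)
import torch
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, TensorDataset, random_split


class MyDataModule(LightningDataModule):
    """Sketch of a datamodule; replace the random tensors with a real dataset."""

    def __init__(self, data_dir: str = "data/", batch_size: int = 64):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

    def setup(self, stage=None):
        # stand-in data so the sketch is self-contained; load your real dataset here
        full = TensorDataset(torch.randn(1000, 784), torch.randint(0, 10, (1000,)))
        self.train_set, self.val_set = random_split(full, [900, 100])

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size)
```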
+<br>
+
+### Logs
+
+**Hydra creates new working directory for every executed run.** By default, logs have the following structure:
+
+```
+├── logs
+│   ├── experiments              # Folder for the logs generated by experiments
+│   │   ├── runs                 # Folder for single runs
+│   │   │   ├── experiment_name          # Experiment name
+│   │   │   │   ├── YYYY-MM-DD_HH-MM-SS  # Datetime of the run
+│   │   │   │   │   ├── .hydra           # Hydra logs
+│   │   │   │   │   ├── csv              # Csv logs
+│   │   │   │   │   ├── wandb            # Weights&Biases logs
+│   │   │   │   │   ├── checkpoints      # Training checkpoints
+│   │   │   │   │   └── ...              # Any other thing saved during training
+│   │   │   │   └── ...
+│   │   │   └── ...
+│   │   │
+│   │   └── multiruns            # Folder for multiruns
+│   │       ├── experiment_name          # Experiment name
+│   │       │   ├── YYYY-MM-DD_HH-MM-SS  # Datetime of the multirun
+│   │       │   │   ├── 1                # Multirun job number
+│   │       │   │   ├── 2
+│   │       │   │   └── ...
+│   │       │   └── ...
+│   │       └── ...
+│   │
+│   ├── evaluations              # Folder for the logs generated during testing
+│   │   └── ...
+│   │
+│   └── debugs                   # Folder for the logs generated during debugging
+│       └── ...
+```
+
+You can change this structure by modifying paths in [hydra configuration](configs/log_dir).
+
+<br>
+
+### Experiment Tracking
+
+PyTorch Lightning supports many popular logging frameworks:<br>
+**[Weights&Biases](https://www.wandb.com/) · [Neptune](https://neptune.ai/) · [Comet](https://www.comet.ml/) · [MLFlow](https://mlflow.org) · [Tensorboard](https://www.tensorflow.org/tensorboard/)**
+
+These tools help you keep track of hyperparameters and output metrics and allow you to compare and visualize results. To use one of them simply complete its configuration in [configs/logger](configs/logger) and run:
+
+```bash
+python train.py logger=logger_name
+```
+
+You can use many of them at once (see [configs/logger/many_loggers.yaml](configs/logger/many_loggers.yaml) for example).
+
+You can also write your own logger.
+
+Lightning provides convenient method for logging custom metrics from inside LightningModule. Read the docs [here](https://pytorch-lightning.readthedocs.io/en/latest/extensions/logging.html#automatic-logging) or take a look at [MNIST example](src/models/mnist_module.py).
+
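The custom-metric logging mentioned above boils down to calling `self.log` inside the LightningModule hooks. A minimal sketch (the metric name and the `self.net` attribute are assumptions, not taken from this diff):

```python
# inside your LightningModule (sketch)
import torch.nn.functional as F

def training_step(self, batch, batch_idx):
    x, y = batch
    loss = F.cross_entropy(self.net(x), y)
    # anything passed to self.log is forwarded to the configured logger(s)
    self.log("train/loss", loss, on_step=False, on_epoch=True, prog_bar=True)
    return loss
```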
+<br>
+
+### Hyperparameter Search
+
+Defining hyperparameter optimization is as easy as adding new config file to [configs/hparams_search](configs/hparams_search).
+
+<details>
+<summary><b>Show example</b></summary>
+
+```yaml
+defaults:
+  - override /hydra/sweeper: optuna
+
+# choose metric which will be optimized by Optuna
+optimized_metric: "val/acc_best"
+
+hydra:
+  # here we define Optuna hyperparameter search
+  # it optimizes for value returned from function with @hydra.main decorator
+  # learn more here: https://hydra.cc/docs/next/plugins/optuna_sweeper
+  sweeper:
+    _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper
+    storage: null
+    study_name: null
+    n_jobs: 1
+
+    # 'minimize' or 'maximize' the objective
+    direction: maximize
+
+    # number of experiments that will be executed
+    n_trials: 20
+
+    # choose Optuna hyperparameter sampler
+    # learn more here: https://optuna.readthedocs.io/en/stable/reference/samplers.html
+    sampler:
+      _target_: optuna.samplers.TPESampler
+      seed: 12345
+      consider_prior: true
+      prior_weight: 1.0
+      consider_magic_clip: true
+      consider_endpoints: false
+      n_startup_trials: 10
+      n_ei_candidates: 24
+      multivariate: false
+      warn_independent_sampling: true
+
+    # define range of hyperparameters
+    search_space:
+      datamodule.batch_size:
+        type: categorical
+        choices: [32, 64, 128]
+      model.lr:
+        type: float
+        low: 0.0001
+        high: 0.2
+      model.lin1_size:
+        type: categorical
+        choices: [32, 64, 128, 256, 512]
+      model.lin2_size:
+        type: categorical
+        choices: [32, 64, 128, 256, 512]
+      model.lin3_size:
+        type: categorical
+        choices: [32, 64, 128, 256, 512]
+```
+
+</details>
+
+Next, you can execute it with: `python train.py -m hparams_search=mnist_optuna`
+
+Using this approach doesn't require you to add any boilerplate into your pipeline, everything is defined in a single config file.
+
+You can use different optimization frameworks integrated with Hydra, like Optuna, Ax or Nevergrad.
+
+The `optimization_results.yaml` will be available under `logs/multirun` folder.
+
+This approach doesn't support advanced techniques like pruning - for more sophisticated search, you probably shouldn't use the hydra multirun feature and instead write your own optimization pipeline.
+
+<br>
+
+### Inference
+
+The following code is an example of loading model from checkpoint and running predictions.<br>
+
+<details>
+<summary><b>Show example</b></summary>
+
+```python
+from PIL import Image
+from torchvision import transforms
+
+from src.models.mnist_module import MNISTLitModule
+
+
+def predict():
+    """Example of inference with trained model.
+    It loads trained image classification model from checkpoint.
+    Then it loads example image and predicts its label.
+    """
+
+    # ckpt can be also a URL!
+    CKPT_PATH = "last.ckpt"
+
+    # load model from checkpoint
+    # model __init__ parameters will be loaded from ckpt automatically
+    # you can also pass some parameter explicitly to override it
+    trained_model = MNISTLitModule.load_from_checkpoint(checkpoint_path=CKPT_PATH)
+
+    # print model hyperparameters
+    print(trained_model.hparams)
+
+    # switch to evaluation mode
+    trained_model.eval()
+    trained_model.freeze()
+
+    # load data
+    img = Image.open("data/example_img.png").convert("L")  # convert to black and white
+    # img = Image.open("data/example_img.png").convert("RGB")  # convert to RGB
+
+    # preprocess
+    mnist_transforms = transforms.Compose(
+        [
+            transforms.ToTensor(),
+            transforms.Resize((28, 28)),
+            transforms.Normalize((0.1307,), (0.3081,)),
+        ]
+    )
+    img = mnist_transforms(img)
+    img = img.reshape((1, *img.size()))  # reshape to form batch of size 1
+
+    # inference
+    output = trained_model(img)
+    print(output)
+
+
+if __name__ == "__main__":
+    predict()
+
+```
+
+</details>
+
+<br>
+
+### Tests
+
+Template comes with example tests implemented with pytest library. To execute them simply run:
+
+```bash
+# run all tests
+pytest
+
+# run tests from specific file
+pytest tests/shell/test_basic_commands.py
+
+# run all tests except the ones marked as slow
+pytest -k "not slow"
+```
+
+To speed up the development, you can once in a while execute tests that run a couple of quick experiments, like training 1 epoch on 25% of data, executing single train/val/test step, etc. Those kinds of tests don't check for any specific output - they exist to simply verify that executing some bash commands doesn't end up throwing exceptions. You can find them implemented in [tests/shell](tests/shell) folder.
+
+You can easily modify the commands in the scripts for your use case. If 1 epoch is too much for your model, then make it run for a couple of batches instead (by using the right trainer flags).
+
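A rough idea of what such a shell/command test can look like; the template's own test helpers are not included in this diff, so this sketch simply shells out to the training script directly:

```python
# tests/shell/test_basic_commands.py (illustrative sketch)
import subprocess

import pytest


@pytest.mark.slow
def test_debug_run_does_not_crash():
    """Smoke test: a short debug run should exit with status 0."""
    result = subprocess.run(
        ["python", "train.py", "debug=default"], capture_output=True, text=True
    )
    assert result.returncode == 0, result.stderr
```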
834 |
+
<br>
|
835 |
+
|
836 |
+
### Callbacks
|
837 |
+
|
838 |
+
The branch [`wandb-callbacks`](https://github.com/ashleve/lightning-hydra-template/tree/wandb-callbacks) contains example callbacks enabling better Weights&Biases integration, which you can use as a reference for writing your own callbacks (see [wandb_callbacks.py](https://github.com/ashleve/lightning-hydra-template/tree/wandb-callbacks/src/callbacks/wandb_callbacks.py)).
|
839 |
+
|
840 |
+
Callbacks which support reproducibility:
|
841 |
+
|
842 |
+
- **WatchModel**
|
843 |
+
- **UploadCodeAsArtifact**
|
844 |
+
- **UploadCheckpointsAsArtifact**
|
845 |
+
|
846 |
+
Callbacks which provide examples of logging custom visualisations:
|
847 |
+
|
848 |
+
- **LogConfusionMatrix**
|
849 |
+
- **LogF1PrecRecHeatmap**
|
850 |
+
- **LogImagePredictions**

To try all of the callbacks at once, switch to the right branch:

```bash
git checkout wandb-callbacks
```

And then run the following command:

```bash
python train.py logger=wandb callbacks=wandb
```

To see the result of all the callbacks attached, take a look at [this experiment dashboard](https://wandb.ai/hobglob/template-tests/runs/3rw7q70h).

<br>

### Multi-GPU Training

Lightning supports multiple ways of doing distributed training. The most common one is DDP, which spawns a separate process for each GPU and averages gradients between them. To learn about other approaches, read the [lightning docs](https://pytorch-lightning.readthedocs.io/en/latest/advanced/multi_gpu.html).

You can run DDP on the MNIST example with 4 GPUs like this:

```bash
python train.py trainer.gpus=4 +trainer.strategy=ddp
```

⚠️ When using DDP you have to be careful how you write your models - learn more [here](https://pytorch-lightning.readthedocs.io/en/latest/advanced/multi_gpu.html).
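
Two common examples of what this means in practice (a sketch, not an exhaustive list; `self.criterion` stands in for whatever loss module your model defines): create tensors on `self.device` instead of calling `.cuda()` directly, and pass `sync_dist=True` when logging metrics that should be reduced across processes:

```python
# illustrative DDP-friendly code inside a LightningModule (assumes torch is imported)
def validation_step(self, batch, batch_idx):
    x, y = batch
    loss = self.criterion(self(x), y)

    # create new tensors on the module's device instead of hardcoding .cuda()
    weights = torch.ones(len(y), device=self.device)  # shown only to illustrate device placement

    # sync_dist=True reduces the logged value across all DDP processes
    self.log("val/loss", loss, sync_dist=True)
    return loss
```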

<br>

### Docker

First you will need to [install Nvidia Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) to enable GPU support.

The template Dockerfile is provided on the [`dockerfiles`](https://github.com/ashleve/lightning-hydra-template/tree/dockerfiles) branch. Copy it to the template root folder.

To build the container use:

```bash
docker build -t <project_name> .
```

To mount the project to the container use:

```bash
docker run -v $(pwd):/workspace/project --gpus all -it --rm <project_name>
```

<br>

### Reproducibility

What provides reproducibility:

- Hydra manages your configs
- Hydra manages your logging paths and makes every executed run store its hyperparameters and config overrides in a separate file in logs
- Single seed for random number generators in pytorch, numpy and python.random (see the sketch after this list)
- LightningDataModule allows you to encapsulate data split, transformations and default parameters in a single, clean abstraction
- LightningModule separates your research code from engineering code in a clean way
- Experiment tracking frameworks take care of logging metrics and hparams, some can also store results and artifacts in the cloud
- Pytorch Lightning takes care of creating training checkpoints
- Example callbacks for wandb show how you can save and upload a snapshot of the codebase every time the run is executed, as well as upload ckpts and track model gradients
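
For instance, the single-seed part amounts to one call near the top of the training pipeline - roughly what the template does, sketched here for reference:

```python
from pytorch_lightning import seed_everything

# `config` is the composed hydra DictConfig passed to the training pipeline
# set seed for random number generators in pytorch, numpy and python.random
if config.get("seed"):
    seed_everything(config.seed, workers=True)
```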

<!--
You can load the config of a previous run using:

```bash
python train.py --config-path /logs/runs/.../.hydra/ --config-name config.yaml
```

The `config.yaml` from the `.hydra` folder contains all overridden parameters and sections. This approach however is not officially supported by Hydra and doesn't override the `hydra/` part of the config, meaning logging paths will revert to default!
-->
<br>

### Limitations

- Currently, the template doesn't support k-fold cross validation, but it's possible to achieve it with the Lightning Loop interface. See the [official example](https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/loop_examples/kfold.py). Implementing it requires rewriting the training pipeline.
- Pytorch Lightning might not be the best choice for scalable reinforcement learning, it's probably better to use something like [Ray](https://github.com/ray-project/ray).
- Currently hyperparameter search with the Hydra Optuna Plugin doesn't support pruning.
- Hydra changes the working directory to a new logging folder for every executed run, which might not be compatible with the way some libraries work.

<br>

## Useful Tricks

<details>
<summary><b>Accessing datamodule attributes in model</b></summary>

1. The simplest way is to pass a datamodule attribute directly to the model on initialization:

   ```python
   # ./src/training_pipeline.py
   datamodule = hydra.utils.instantiate(config.datamodule)
   model = hydra.utils.instantiate(config.model, some_param=datamodule.some_param)
   ```

   This is not a very robust solution, since it assumes all your datamodules have a `some_param` attribute available (otherwise the run will crash).

2. If you only want to access the datamodule config, you can simply pass it as an init parameter:

   ```python
   # ./src/training_pipeline.py
   model = hydra.utils.instantiate(config.model, dm_conf=config.datamodule, _recursive_=False)
   ```

   Now you can access any part of the datamodule config like this:

   ```python
   # ./src/models/my_model.py
   class MyLitModel(LightningModule):
       def __init__(self, dm_conf, param1, param2):
           super().__init__()

           batch_size = dm_conf.batch_size
   ```

3. If you need to access the datamodule object attributes, a little hacky solution is to add an OmegaConf resolver to your datamodule:

   ```python
   # ./src/datamodules/my_datamodule.py
   from omegaconf import OmegaConf

   class MyDataModule(LightningDataModule):
       def __init__(self, param1, param2):
           super().__init__()

           self.param1 = param1

           resolver_name = "datamodule"
           OmegaConf.register_new_resolver(
               resolver_name,
               lambda name: getattr(self, name),
               use_cache=False
           )
   ```

   This way you can reference any datamodule attribute from your config like this:

   ```yaml
   # this will return attribute 'param1' from datamodule object
   param1: ${datamodule: param1}
   ```

   When later accessing this field, say in your lightning model, it will get automatically resolved based on all resolvers that are registered. Remember not to access this field before the datamodule is initialized or it will crash. **You also need to set `resolve=False` in `print_config()` in [train.py](train.py) or it will throw errors:**

   ```python
   # ./src/train.py
   utils.print_config(config, resolve=False)
   ```

</details>

<details>
<summary><b>Automatic activation of virtual environment and tab completion when entering folder</b></summary>

1. Create a new file called `.autoenv` (this name is excluded from version control in `.gitignore`). <br>
   You can use it to automatically execute shell commands when entering the folder. Add some commands to your `.autoenv` file, like in the example below:

   ```bash
   # activate conda environment
   conda activate myenv

   # activate hydra tab completion for bash
   eval "$(python train.py -sc install=bash)"
   ```

   (these commands will be executed whenever you open a terminal in, or switch it to, a folder containing the `.autoenv` file)

2. To set up this automation for bash, execute the following line (it will append to your `.bashrc` file):

   ```bash
   echo "autoenv() { if [ -x .autoenv ]; then source .autoenv ; echo '.autoenv executed' ; fi } ; cd() { builtin cd \"\$@\" ; autoenv ; } ; autoenv" >> ~/.bashrc
   ```

3. Lastly, add execution privileges to your `.autoenv` file:

   ```
   chmod +x .autoenv
   ```

   (for safety, only an `.autoenv` with execution privileges will be executed)

**Explanation**

The mentioned line appends 2 commands to your `.bashrc` file:

1. `autoenv() { if [ -x .autoenv ]; then source .autoenv ; echo '.autoenv executed' ; fi }` - this declares the `autoenv()` function, which executes the `.autoenv` file if it exists in the current work dir and has execution privileges
2. `cd() { builtin cd \"\$@\" ; autoenv ; } ; autoenv` - this extends the behaviour of the `cd` command, making it execute the `autoenv()` function each time you change folder in the terminal or open a new terminal

</details>

<!--
<details>
<summary><b>Making sweeps failure resistant</b></summary>

TODO

</details>
-->

<br>

## Best Practices

<details>
<summary><b>Use Miniconda for GPU environments</b></summary>

Use miniconda for your python environments (it's usually unnecessary to install the full anaconda environment, miniconda should be enough).
It makes it easier to install some dependencies, like cudatoolkit for GPU support. It also allows you to access your environments globally.

Example installation:

```bash
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
bash Miniconda3-latest-Linux-x86_64.sh
```

Create a new conda environment:

```bash
conda create -n myenv python=3.8
conda activate myenv
```
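
Inside the new environment you can then install PyTorch together with a matching cudatoolkit, for example (adjust the CUDA version to your driver; check [pytorch.org](https://pytorch.org/get-started/) for the current command):

```bash
conda install pytorch torchvision cudatoolkit=11.3 -c pytorch
```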

</details>

<details>
<summary><b>Use automatic code formatting</b></summary>

Use pre-commit hooks to standardize code formatting of your project and save mental energy.<br>
Simply install the pre-commit package with:

```bash
pip install pre-commit
```

Next, install hooks from [.pre-commit-config.yaml](.pre-commit-config.yaml):

```bash
pre-commit install
```

After that your code will be automatically reformatted on every new commit.<br>
Currently the template contains configurations of **black** (python code formatting), **isort** (python import sorting), **flake8** (python code analysis), **prettier** (yaml formatting) and **nbstripout** (clearing output from jupyter notebooks). <br>
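
For illustration, a hook entry in such a config looks roughly like this (an excerpt with placeholder revisions, not necessarily identical to the template's [.pre-commit-config.yaml](.pre-commit-config.yaml)):

```yaml
repos:
  - repo: https://github.com/psf/black
    rev: 22.3.0
    hooks:
      - id: black
  - repo: https://github.com/PyCQA/isort
    rev: 5.10.1
    hooks:
      - id: isort
```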

To reformat all files in the project, use the command:

```bash
pre-commit run -a
```

</details>

<details>
<summary><b>Set private environment variables in .env file</b></summary>

System specific variables (e.g. absolute paths to datasets) should not be under version control or it will result in conflicts between different users. Your private keys also shouldn't be versioned since you don't want them to be leaked.<br>

The template contains a `.env.example` file, which serves as an example. Create a new file called `.env` (this name is excluded from version control in `.gitignore`).
You should use it for storing environment variables like this:

```
MY_VAR=/home/user/my_system_path
```

All variables from `.env` are loaded in `train.py` automatically.
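
Under the hood this is typically a one-liner with the `python-dotenv` package - a sketch of the loading step (the template's `train.py` may phrase it slightly differently):

```python
import dotenv

# load environment variables from the .env file before hydra composes the config
dotenv.load_dotenv(override=True)
```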

Hydra allows you to reference any env variable in `.yaml` configs like this:

```yaml
path_to_data: ${oc.env:MY_VAR}
```

</details>

<details>
<summary><b>Name metrics using '/' character</b></summary>

Depending on which logger you're using, it's often useful to define metric names with the `/` character:

```python
self.log("train/loss", loss)
```

This way loggers will treat your metrics as belonging to different sections, which helps to get them organised in the UI.

</details>

<details>
<summary><b>Use torchmetrics</b></summary>

Use the official [torchmetrics](https://github.com/PytorchLightning/metrics) library to ensure proper calculation of metrics. This is especially important for multi-GPU training!

For example, instead of calculating accuracy by yourself, you should use the provided `Accuracy` class like this:

```python
from torchmetrics.classification.accuracy import Accuracy


class LitModel(LightningModule):
    def __init__(self):
        self.train_acc = Accuracy()
        self.val_acc = Accuracy()

    def training_step(self, batch, batch_idx):
        ...
        acc = self.train_acc(predictions, targets)
        self.log("train/acc", acc)
        ...

    def validation_step(self, batch, batch_idx):
        ...
        acc = self.val_acc(predictions, targets)
        self.log("val/acc", acc)
        ...
```

Make sure to use a different metric instance for each step to ensure proper value reduction over all GPU processes.

Torchmetrics provides metrics for most use cases, like F1 score or confusion matrix. Read the [documentation](https://torchmetrics.readthedocs.io/en/latest/#more-reading) for more.

</details>

<details>
<summary><b>Follow PyTorch Lightning style guide</b></summary>

The style guide is available [here](https://pytorch-lightning.readthedocs.io/en/latest/starter/style_guide.html).<br>

1. Be explicit in your init. Try to define all the relevant defaults so that the user doesn't have to guess. Provide type hints. This way your module is reusable across projects!

   ```python
   class LitModel(LightningModule):
       def __init__(self, layer_size: int = 256, lr: float = 0.001):
   ```

2. Preserve the recommended method order.

   ```python
   class LitModel(LightningModule):

       def __init__():
           ...

       def forward():
           ...

       def training_step():
           ...

       def training_step_end():
           ...

       def training_epoch_end():
           ...

       def validation_step():
           ...

       def validation_step_end():
           ...

       def validation_epoch_end():
           ...

       def test_step():
           ...

       def test_step_end():
           ...

       def test_epoch_end():
           ...

       def configure_optimizers():
           ...

       def any_extra_hook():
           ...
   ```

</details>

<details>
<summary><b>Version control your data and models with DVC</b></summary>

Use [DVC](https://dvc.org) to version control big files, like your data or trained ML models.<br>
To initialize the dvc repository:

```bash
dvc init
```

To start tracking a file or directory, use `dvc add`:

```bash
dvc add data/MNIST
```

DVC stores information about the added file (or directory) in a special .dvc file named data/MNIST.dvc, a small text file with a human-readable format. This file can be easily versioned like source code with Git, as a placeholder for the original data:

```bash
git add data/MNIST.dvc data/.gitignore
git commit -m "Add raw data"
```
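
To actually move the data between machines, point DVC at a remote and push/pull - for example with an S3 bucket (the bucket URL below is a placeholder):

```bash
# configure a default remote (placeholder URL)
dvc remote add -d storage s3://mybucket/dvcstore

# upload tracked data to the remote, or fetch it on another machine
dvc push
dvc pull
```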

</details>

<details>
<summary><b>Support installing project as a package</b></summary>

It allows other people to easily use your modules in their own projects.
Change the name of the `src` folder to your project name and add a `setup.py` file:

```python
from setuptools import find_packages, setup


setup(
    name="src",  # change "src" folder name to your project name
    version="0.0.0",
    description="Describe Your Cool Project",
    author="...",
    author_email="...",
    url="https://github.com/ashleve/lightning-hydra-template",  # replace with your own github project link
    install_requires=[
        "torch>=1.10.0",
        "pytorch-lightning>=1.4.0",
        "hydra-core>=1.1.0",
    ],
    packages=find_packages(),
)
```

Now your project can be installed from local files:

```bash
pip install -e .
```

Or directly from the git repository:

```bash
pip install git+https://github.com/YourGithubName/your-repo-name.git --upgrade
```

So any file can be easily imported into any other file like so:

```python
from project_name.models.mnist_module import MNISTLitModule
from project_name.datamodules.mnist_datamodule import MNISTDataModule
```

</details>

<!-- <details>
<summary><b>Make notebooks independent from other files</b></summary>

It's a good practice for jupyter notebooks to be portable. Try to make them independent from src files. If you need to access external code, try to embed it inside the notebook.

</details> -->

<!--<details>
<summary><b>Use Docker</b></summary>

Docker makes it easy to initialize the whole training environment, e.g. when you want to execute experiments in the cloud or on some private computing cluster. You can extend the [dockerfiles](https://github.com/ashleve/lightning-hydra-template/tree/dockerfiles) provided in the template with your own instructions for building the image.<br>

</details> -->

<br>

## Other Repositories

<details>
<summary><b>Inspirations</b></summary>

This template was inspired by:
[PyTorchLightning/deep-learning-project-template](https://github.com/PyTorchLightning/deep-learning-project-template),
[drivendata/cookiecutter-data-science](https://github.com/drivendata/cookiecutter-data-science),
[tchaton/lightning-hydra-seed](https://github.com/tchaton/lightning-hydra-seed),
[Erlemar/pytorch_tempest](https://github.com/Erlemar/pytorch_tempest),
[lucmos/nn-template](https://github.com/lucmos/nn-template).

</details>

<details>
<summary><b>Useful repositories</b></summary>

- [pytorch/hydra-torch](https://github.com/pytorch/hydra-torch) - resources for configuring PyTorch classes with Hydra,
- [romesco/hydra-lightning](https://github.com/romesco/hydra-lightning) - resources for configuring PyTorch Lightning classes with Hydra
- [lucmos/nn-template](https://github.com/lucmos/nn-template) - similar template
- [PyTorchLightning/lightning-transformers](https://github.com/PyTorchLightning/lightning-transformers) - official Lightning Transformers repo built with Hydra

</details>

<!-- ## :star: Stargazers Over Time

[![Stargazers over time](https://starchart.cc/ashleve/lightning-hydra-template.svg)](https://starchart.cc/ashleve/lightning-hydra-template) -->

<br>

## License

This project is licensed under the MIT License.

```
MIT License

Copyright (c) 2021 ashleve

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
```

<br>
<br>
<br>
<br>

**DELETE EVERYTHING ABOVE FOR YOUR PROJECT**

---

<div align="center">

# Your Project Name

<a href="https://pytorch.org/get-started/locally/"><img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-ee4c2c?logo=pytorch&logoColor=white"></a>
<a href="https://pytorchlightning.ai/"><img alt="Lightning" src="https://img.shields.io/badge/-Lightning-792ee5?logo=pytorchlightning&logoColor=white"></a>
<a href="https://hydra.cc/"><img alt="Config: Hydra" src="https://img.shields.io/badge/Config-Hydra-89b8cd"></a>
<a href="https://github.com/ashleve/lightning-hydra-template"><img alt="Template" src="https://img.shields.io/badge/-Lightning--Hydra--Template-017F2F?style=flat&logo=github&labelColor=gray"></a><br>
[![Paper](http://img.shields.io/badge/paper-arxiv.1001.2234-B31B1B.svg)](https://www.nature.com/articles/nature14539)
[![Conference](http://img.shields.io/badge/AnyConference-year-4b44ce.svg)](https://papers.nips.cc/paper/2020)

</div>

## Description

What it does

## How to run

Install dependencies

```bash
# clone project
git clone https://github.com/YourGithubName/your-repo-name
cd your-repo-name

# [OPTIONAL] create conda environment
conda create -n myenv python=3.8
conda activate myenv

# install pytorch according to instructions
# https://pytorch.org/get-started/

# install requirements
pip install -r requirements.txt
```

Train model with default configuration

```bash
# train on CPU
python train.py trainer.gpus=0

# train on GPU
python train.py trainer.gpus=1
```

Train model with a chosen experiment configuration from [configs/experiment/](configs/experiment/)

```bash
python train.py experiment=experiment_name.yaml
```

You can override any parameter from the command line like this

```bash
python train.py trainer.max_epochs=20 datamodule.batch_size=64
```

models/configs/callbacks/default.yaml
ADDED
@@ -0,0 +1,24 @@
model_checkpoint:
  _target_: pytorch_lightning.callbacks.ModelCheckpoint
  monitor: "val/acc" # name of the logged metric which determines when model is improving
  mode: "max" # "max" means higher metric value is better, can be also "min"
  save_top_k: 1 # save k best models (determined by above metric)
  save_last: True # additionally always save model from last epoch
  verbose: False
  dirpath: "checkpoints/"
  filename: "epoch_{epoch:03d}"
  auto_insert_metric_name: False

early_stopping:
  _target_: pytorch_lightning.callbacks.EarlyStopping
  monitor: "val/acc" # name of the logged metric which determines when model is improving
  mode: "max" # "max" means higher metric value is better, can be also "min"
  patience: 100 # how many validation epochs of not improving until training stops
  min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement

model_summary:
  _target_: pytorch_lightning.callbacks.RichModelSummary
  max_depth: -1

rich_progress_bar:
  _target_: pytorch_lightning.callbacks.RichProgressBar
models/configs/callbacks/none.yaml
ADDED
File without changes

models/configs/datamodule/mnist.yaml
ADDED
@@ -0,0 +1,7 @@
_target_: src.datamodules.mnist_datamodule.MNISTDataModule

data_dir: ${data_dir} # data_dir is specified in config.yaml
batch_size: 64
train_val_test_split: [55_000, 5_000, 10_000]
num_workers: 0
pin_memory: False
models/configs/debug/default.yaml
ADDED
@@ -0,0 +1,28 @@
# @package _global_

# default debugging setup, runs 1 full epoch
# other debugging configs can inherit from this one

defaults:
  - override /log_dir: debug.yaml

trainer:
  max_epochs: 1
  gpus: 0 # debuggers don't like gpus
  detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor
  track_grad_norm: 2 # track gradient norm with loggers

datamodule:
  num_workers: 0 # debuggers don't like multiprocessing
  pin_memory: False # disable gpu memory pin

# sets level of all command line loggers to 'DEBUG'
# https://hydra.cc/docs/tutorials/basic/running_your_app/logging/
hydra:
  verbose: True

  # use this to set level of only chosen command line loggers to 'DEBUG':
  # verbose: [src.train, src.utils]

# config is already printed by hydra when `hydra/verbose: True`
print_config: False

models/configs/debug/limit_batches.yaml
ADDED
@@ -0,0 +1,12 @@
# @package _global_

# uses only 1% of the training data and 5% of validation/test data

defaults:
  - default.yaml

trainer:
  max_epochs: 3
  limit_train_batches: 0.01
  limit_val_batches: 0.05
  limit_test_batches: 0.05

models/configs/debug/overfit.yaml
ADDED
@@ -0,0 +1,10 @@
# @package _global_

# overfits to 3 batches

defaults:
  - default.yaml

trainer:
  max_epochs: 20
  overfit_batches: 3

models/configs/debug/profiler.yaml
ADDED
@@ -0,0 +1,12 @@
# @package _global_

# runs with execution time profiling

defaults:
  - default.yaml

trainer:
  max_epochs: 1
  profiler: "simple"
  # profiler: "advanced"
  # profiler: "pytorch"

models/configs/debug/step.yaml
ADDED
@@ -0,0 +1,9 @@
# @package _global_

# runs 1 train, 1 validation and 1 test step

defaults:
  - default.yaml

trainer:
  fast_dev_run: true

models/configs/debug/test_only.yaml
ADDED
@@ -0,0 +1,9 @@
# @package _global_

# runs only test epoch

defaults:
  - default.yaml

train: False
test: True
models/configs/experiment/example.yaml
ADDED
@@ -0,0 +1,37 @@
# @package _global_

# to execute this experiment run:
# python train.py experiment=example

defaults:
  - override /datamodule: mnist.yaml
  - override /model: mnist.yaml
  - override /callbacks: default.yaml
  - override /logger: null
  - override /trainer: default.yaml

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

# name of the run determines folder name in logs
name: "simple_dense_net"

seed: 12345

trainer:
  min_epochs: 10
  max_epochs: 10
  gradient_clip_val: 0.5

model:
  lin1_size: 128
  lin2_size: 256
  lin3_size: 64
  lr: 0.002

datamodule:
  batch_size: 64

logger:
  wandb:
    tags: ["mnist", "${name}"]

models/configs/hparams_search/mnist_optuna.yaml
ADDED
@@ -0,0 +1,60 @@
# @package _global_

# example hyperparameter optimization of some experiment with Optuna:
# python train.py -m hparams_search=mnist_optuna experiment=example

defaults:
  - override /hydra/sweeper: optuna

# choose metric which will be optimized by Optuna
# make sure this is the correct name of some metric logged in lightning module!
optimized_metric: "val/acc_best"

# here we define Optuna hyperparameter search
# it optimizes for value returned from function with @hydra.main decorator
# docs: https://hydra.cc/docs/next/plugins/optuna_sweeper
hydra:
  sweeper:
    _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper

    # storage URL to persist optimization results
    # for example, you can use SQLite if you set 'sqlite:///example.db'
    storage: null

    # name of the study to persist optimization results
    study_name: null

    # number of parallel workers
    n_jobs: 1

    # 'minimize' or 'maximize' the objective
    direction: maximize

    # total number of runs that will be executed
    n_trials: 25

    # choose Optuna hyperparameter sampler
    # docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html
    sampler:
      _target_: optuna.samplers.TPESampler
      seed: 12345
      n_startup_trials: 10 # number of random sampling runs before optimization starts

    # define range of hyperparameters
    search_space:
      datamodule.batch_size:
        type: categorical
        choices: [32, 64, 128]
      model.lr:
        type: float
        low: 0.0001
        high: 0.2
      model.lin1_size:
        type: categorical
        choices: [32, 64, 128, 256, 512]
      model.lin2_size:
        type: categorical
        choices: [32, 64, 128, 256, 512]
      model.lin3_size:
        type: categorical
        choices: [32, 64, 128, 256, 512]
models/configs/local/.gitkeep
ADDED
File without changes

models/configs/log_dir/debug.yaml
ADDED
@@ -0,0 +1,8 @@
# @package _global_

hydra:
  run:
    dir: logs/debugs/runs/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}
  sweep:
    dir: logs/debugs/multiruns/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}
    subdir: ${hydra.job.num}

models/configs/log_dir/default.yaml
ADDED
@@ -0,0 +1,8 @@
# @package _global_

hydra:
  run:
    dir: logs/experiments/runs/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}
  sweep:
    dir: logs/experiments/multiruns/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}
    subdir: ${hydra.job.num}

models/configs/log_dir/evaluation.yaml
ADDED
@@ -0,0 +1,8 @@
# @package _global_

hydra:
  run:
    dir: logs/evaluations/runs/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}
  sweep:
    dir: logs/evaluations/multiruns/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}
    subdir: ${hydra.job.num}

models/configs/logger/comet.yaml
ADDED
@@ -0,0 +1,7 @@
# https://www.comet.ml

comet:
  _target_: pytorch_lightning.loggers.comet.CometLogger
  api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable
  project_name: "template-tests"
  experiment_name: ${name}

models/configs/logger/csv.yaml
ADDED
@@ -0,0 +1,7 @@
# csv logger built in lightning

csv:
  _target_: pytorch_lightning.loggers.csv_logs.CSVLogger
  save_dir: "."
  name: "csv/"
  prefix: ""

models/configs/logger/many_loggers.yaml
ADDED
@@ -0,0 +1,9 @@
# train with many loggers at once

defaults:
  # - comet.yaml
  - csv.yaml
  # - mlflow.yaml
  # - neptune.yaml
  - tensorboard.yaml
  - wandb.yaml

models/configs/logger/mlflow.yaml
ADDED
@@ -0,0 +1,10 @@
# https://mlflow.org

mlflow:
  _target_: pytorch_lightning.loggers.mlflow.MLFlowLogger
  experiment_name: ${name}
  tracking_uri: null
  tags: null
  save_dir: ./mlruns
  prefix: ""
  artifact_location: null