Hannes Kuchelmeister committed on
Commit · b72a776
Parent(s): 3d5c288
add model to repository
Browse files
- model/LICENSE +21 -0
- model/README.md +378 -0
- model/base/__init__.py +3 -0
- model/base/base_data_loader.py +61 -0
- model/base/base_model.py +25 -0
- model/base/base_trainer.py +151 -0
- model/config.json +50 -0
- model/data_loader/data_loaders.py +16 -0
- model/logger/__init__.py +2 -0
- model/logger/logger.py +22 -0
- model/logger/logger_config.json +32 -0
- model/logger/visualization.py +73 -0
- model/model/loss.py +5 -0
- model/model/metric.py +20 -0
- model/model/model.py +22 -0
- model/new_project.py +18 -0
- model/parse_config.py +157 -0
- model/requirements.txt +5 -0
- model/test.py +81 -0
- model/train.py +73 -0
- model/trainer/__init__.py +1 -0
- model/trainer/trainer.py +110 -0
- model/utils/__init__.py +1 -0
- model/utils/util.py +67 -0
model/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2018 Victor Huang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
model/README.md
ADDED
@@ -0,0 +1,378 @@
# PyTorch Template Project
PyTorch deep learning project made easy.

<!-- @import "[TOC]" {cmd="toc" depthFrom=1 depthTo=6 orderedList=false} -->

<!-- code_chunk_output -->

* [PyTorch Template Project](#pytorch-template-project)
    * [Requirements](#requirements)
    * [Features](#features)
    * [Folder Structure](#folder-structure)
    * [Usage](#usage)
        * [Config file format](#config-file-format)
        * [Using config files](#using-config-files)
        * [Resuming from checkpoints](#resuming-from-checkpoints)
        * [Using Multiple GPU](#using-multiple-gpu)
    * [Customization](#customization)
        * [Custom CLI options](#custom-cli-options)
        * [Data Loader](#data-loader)
        * [Trainer](#trainer)
        * [Model](#model)
        * [Loss](#loss)
        * [Metrics](#metrics)
        * [Additional logging](#additional-logging)
        * [Validation data](#validation-data)
        * [Checkpoints](#checkpoints)
        * [Tensorboard Visualization](#tensorboard-visualization)
    * [Contribution](#contribution)
    * [TODOs](#todos)
    * [License](#license)
    * [Acknowledgements](#acknowledgements)

<!-- /code_chunk_output -->

## Requirements
* Python >= 3.5 (3.6 recommended)
* PyTorch >= 0.4 (1.2 recommended)
* tqdm (optional for `test.py`)
* tensorboard >= 1.14 (see [Tensorboard Visualization](#tensorboard-visualization))

## Features
* Clear folder structure which is suitable for many deep learning projects.
* `.json` config file support for convenient parameter tuning.
* Customizable command line options for more convenient parameter tuning.
* Checkpoint saving and resuming.
* Abstract base classes for faster development:
  * `BaseTrainer` handles checkpoint saving/resuming, training process logging, and more.
  * `BaseDataLoader` handles batch generation, data shuffling, and validation data splitting.
  * `BaseModel` provides basic model summary.

## Folder Structure
```
pytorch-template/
│
├── train.py - main script to start training
├── test.py - evaluation of trained model
│
├── config.json - holds configuration for training
├── parse_config.py - class to handle config file and cli options
│
├── new_project.py - initialize new project with template files
│
├── base/ - abstract base classes
│   ├── base_data_loader.py
│   ├── base_model.py
│   └── base_trainer.py
│
├── data_loader/ - anything about data loading goes here
│   └── data_loaders.py
│
├── data/ - default directory for storing input data
│
├── model/ - models, losses, and metrics
│   ├── model.py
│   ├── metric.py
│   └── loss.py
│
├── saved/
│   ├── models/ - trained models are saved here
│   └── log/ - default logdir for tensorboard and logging output
│
├── trainer/ - trainers
│   └── trainer.py
│
├── logger/ - module for tensorboard visualization and logging
│   ├── visualization.py
│   ├── logger.py
│   └── logger_config.json
│
└── utils/ - small utility functions
    ├── util.py
    └── ...
```

## Usage
The code in this repo is an MNIST example of the template.
Try `python train.py -c config.json` to run the code.

### Config file format
Config files are in `.json` format:
```javascript
{
  "name": "Mnist_LeNet",        // training session name
  "n_gpu": 1,                   // number of GPUs to use for training.

  "arch": {
    "type": "MnistModel",       // name of model architecture to train
    "args": {

    }
  },
  "data_loader": {
    "type": "MnistDataLoader",  // selecting data loader
    "args": {
      "data_dir": "data/",      // dataset path
      "batch_size": 64,         // batch size
      "shuffle": true,          // shuffle training data before splitting
      "validation_split": 0.1,  // size of validation dataset. float(portion) or int(number of samples)
      "num_workers": 2          // number of cpu processes to be used for data loading
    }
  },
  "optimizer": {
    "type": "Adam",
    "args": {
      "lr": 0.001,              // learning rate
      "weight_decay": 0,        // (optional) weight decay
      "amsgrad": true
    }
  },
  "loss": "nll_loss",           // loss
  "metrics": [
    "accuracy", "top_k_acc"     // list of metrics to evaluate
  ],
  "lr_scheduler": {
    "type": "StepLR",           // learning rate scheduler
    "args": {
      "step_size": 50,
      "gamma": 0.1
    }
  },
  "trainer": {
    "epochs": 100,              // number of training epochs
    "save_dir": "saved/",       // checkpoints are saved in save_dir/models/name
    "save_period": 1,           // save checkpoints every save_period epochs
    "verbosity": 2,             // 0: quiet, 1: per epoch, 2: full

    "monitor": "min val_loss",  // mode and metric for model performance monitoring. set 'off' to disable.
    "early_stop": 10,           // number of epochs to wait before early stop. set 0 to disable.

    "tensorboard": true         // enable tensorboard visualization
  }
}
```

Add additional configurations if you need them.

### Using config files
Modify the configurations in `.json` config files, then run:

```
python train.py --config config.json
```

### Resuming from checkpoints
You can resume from a previously saved checkpoint by:

```
python train.py --resume path/to/checkpoint
```

### Using Multiple GPU
You can enable multi-GPU training by setting the `n_gpu` argument of the config file to a larger number.
If it is configured to use fewer GPUs than are available, the first n devices will be used by default.
Specify the indices of available GPUs with the CUDA environment variable.
```
python train.py --device 2,3 -c config.json
```
This is equivalent to
```
CUDA_VISIBLE_DEVICES=2,3 python train.py -c config.json
```

## Customization

### Project initialization
Use the `new_project.py` script to make your new project directory with template files.
Run `python new_project.py ../NewProject` and a new project folder named 'NewProject' will be made.
This script will filter out unnecessary files like caches, git files, or the readme file.

### Custom CLI options

Changing the values in the config file is a clean, safe, and easy way of tuning hyperparameters. However, sometimes
it is better to have command-line options if some values need to be changed often or quickly.

This template uses the configurations stored in the json file by default, but by registering custom options as follows
you can change some of them using CLI flags.

```python
# simple class-like object having 3 attributes, `flags`, `type`, `target`.
CustomArgs = collections.namedtuple('CustomArgs', 'flags type target')
options = [
    CustomArgs(['--lr', '--learning_rate'], type=float, target=('optimizer', 'args', 'lr')),
    CustomArgs(['--bs', '--batch_size'], type=int, target=('data_loader', 'args', 'batch_size'))
    # options added here can be modified by command line flags.
]
```
The `target` argument should be a sequence of keys, which are used to access that option in the config dict. In this example, `target`
for the learning rate option is `('optimizer', 'args', 'lr')` because `config['optimizer']['args']['lr']` points to the learning rate.
`python train.py -c config.json --bs 256` runs training with the options given in `config.json`, except for the batch size,
which is increased to 256 by the command-line option.


### Data Loader
* **Writing your own data loader**

1. **Inherit ```BaseDataLoader```**

    `BaseDataLoader` is a subclass of `torch.utils.data.DataLoader`; you can use either of them.

    `BaseDataLoader` handles:
    * Generating next batch
    * Data shuffling
    * Generating validation data loader by calling
    `BaseDataLoader.split_validation()`

* **DataLoader Usage**

  `BaseDataLoader` is an iterator; to iterate through batches:
  ```python
  for batch_idx, (x_batch, y_batch) in enumerate(data_loader):
      pass
  ```
* **Example**

  Please refer to `data_loader/data_loaders.py` for an MNIST data loading example.

### Trainer
* **Writing your own trainer**

1. **Inherit ```BaseTrainer```**

    `BaseTrainer` handles:
    * Training process logging
    * Checkpoint saving
    * Checkpoint resuming
    * Reconfigurable performance monitoring for saving the current best model, and early stopping of training.
      * If config `monitor` is set to `max val_accuracy`, the trainer will save a checkpoint `model_best.pth` whenever the epoch's validation accuracy exceeds the current maximum.
      * If config `early_stop` is set, training will be automatically terminated when model performance does not improve for the given number of epochs. This feature can be turned off by passing 0 to the `early_stop` option, or by deleting that line from the config.

2. **Implementing abstract methods**

    You need to implement `_train_epoch()` for your training process; if you need validation, implement `_valid_epoch()` as in `trainer/trainer.py`.

* **Example**

  Please refer to `trainer/trainer.py` for MNIST training.

* **Iteration-based training**

  `Trainer.__init__` takes an optional argument `len_epoch`, which controls the number of batches (steps) in each epoch.

### Model
* **Writing your own model**

1. **Inherit `BaseModel`**

    `BaseModel` handles:
    * Inherited from `torch.nn.Module`
    * `__str__`: modifies the native `print` output to also show the number of trainable parameters.

2. **Implementing abstract methods**

    Implement the forward pass method `forward()`.

* **Example**

  Please refer to `model/model.py` for a LeNet example.

### Loss
Custom loss functions can be implemented in `model/loss.py`. Use them by changing the name given under "loss" in the config file to the corresponding function name.

### Metrics
Metric functions are located in `model/metric.py`.

You can monitor multiple metrics by providing a list in the configuration file, e.g.:
```json
"metrics": ["accuracy", "top_k_acc"],
```

### Additional logging
If you have additional information to be logged, merge it with `log` in `_train_epoch()` of your trainer class as shown below before returning:

```python
additional_log = {"gradient_norm": g, "sensitivity": s}
log.update(additional_log)
return log
```

### Testing
You can test a trained model by running `test.py` and passing the path to the trained checkpoint with the `--resume` argument.

### Validation data
To split validation data from a data loader, call `BaseDataLoader.split_validation()`; it will return a data loader for validation, with the size specified in your config file.
The `validation_split` can be a ratio of the validation set to the total data (0.0 <= float < 1.0), or the number of samples (0 <= int < `n_total_samples`).

**Note**: the `split_validation()` method will modify the original data loader.
**Note**: `split_validation()` will return `None` if `"validation_split"` is set to `0`.

### Checkpoints
You can specify the name of the training session in config files:
```json
"name": "MNIST_LeNet",
```

The checkpoints will be saved in `save_dir/name/timestamp/checkpoint_epoch_n`, with the timestamp in mmdd_HHMMSS format.

A copy of the config file will be saved in the same folder.

**Note**: checkpoints contain:
```python
{
  'arch': arch,
  'epoch': epoch,
  'state_dict': self.model.state_dict(),
  'optimizer': self.optimizer.state_dict(),
  'monitor_best': self.mnt_best,
  'config': self.config
}
```

### Tensorboard Visualization
This template supports Tensorboard visualization by using either `torch.utils.tensorboard` or [TensorboardX](https://github.com/lanpa/tensorboardX).

1. **Install**

    If you are using pytorch 1.1 or higher, install tensorboard with `pip install tensorboard>=1.14.0`.

    Otherwise, you should install tensorboardX. Follow the installation guide in [TensorboardX](https://github.com/lanpa/tensorboardX).

2. **Run training**

    Make sure that the `tensorboard` option in the config file is turned on.

    ```
    "tensorboard" : true
    ```

3. **Open Tensorboard server**

    Type `tensorboard --logdir saved/log/` at the project root, then the server will open at `http://localhost:6006`.

By default, values of loss and metrics specified in the config file, input images, and histograms of model parameters will be logged.
If you need more visualizations, use `add_scalar('tag', data)`, `add_image('tag', image)`, etc. in the `trainer._train_epoch` method.
The `add_something()` methods in this template are basically wrappers for those of the `tensorboardX.SummaryWriter` and `torch.utils.tensorboard.SummaryWriter` modules.

**Note**: You don't have to specify current steps, since the `TensorboardWriter` class defined in `logger/visualization.py` will track the current step.

## Contribution
Feel free to contribute any kind of function or enhancement; the coding style here follows PEP8.

Code should pass the [Flake8](http://flake8.pycqa.org/en/latest/) check before committing.

## TODOs

- [ ] Multiple optimizers
- [ ] Support more tensorboard functions
- [x] Using fixed random seed
- [x] Support pytorch native tensorboard
- [x] `tensorboardX` logger support
- [x] Configurable logging layout, checkpoint naming
- [x] Iteration-based training (instead of epoch-based)
- [x] Adding command line option for fine-tuning

## License
This project is licensed under the MIT License. See LICENSE for more details.

## Acknowledgements
This project is inspired by the project [Tensorflow-Project-Template](https://github.com/MrGemy95/Tensorflow-Project-Template) by [Mahmoud Gemy](https://github.com/MrGemy95)
model/base/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .base_data_loader import *
from .base_model import *
from .base_trainer import *
model/base/base_data_loader.py
ADDED
@@ -0,0 +1,61 @@
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data.sampler import SubsetRandomSampler


class BaseDataLoader(DataLoader):
    """
    Base class for all data loaders
    """
    def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=default_collate):
        self.validation_split = validation_split
        self.shuffle = shuffle

        self.batch_idx = 0
        self.n_samples = len(dataset)

        self.sampler, self.valid_sampler = self._split_sampler(self.validation_split)

        self.init_kwargs = {
            'dataset': dataset,
            'batch_size': batch_size,
            'shuffle': self.shuffle,
            'collate_fn': collate_fn,
            'num_workers': num_workers
        }
        super().__init__(sampler=self.sampler, **self.init_kwargs)

    def _split_sampler(self, split):
        if split == 0.0:
            return None, None

        idx_full = np.arange(self.n_samples)

        np.random.seed(0)
        np.random.shuffle(idx_full)

        if isinstance(split, int):
            assert split > 0
            assert split < self.n_samples, "validation set size is configured to be larger than entire dataset."
            len_valid = split
        else:
            len_valid = int(self.n_samples * split)

        valid_idx = idx_full[0:len_valid]
        train_idx = np.delete(idx_full, np.arange(0, len_valid))

        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetRandomSampler(valid_idx)

        # turn off shuffle option which is mutually exclusive with sampler
        self.shuffle = False
        self.n_samples = len(train_idx)

        return train_sampler, valid_sampler

    def split_validation(self):
        if self.valid_sampler is None:
            return None
        else:
            return DataLoader(sampler=self.valid_sampler, **self.init_kwargs)
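`BaseDataLoader` can be reused for datasets other than MNIST by passing any `torch.utils.data.Dataset` to its constructor. Below is a minimal sketch, not part of this commit, with a hypothetical class name and random stand-in data; it assumes the `base` package above is on the import path.

```python
import torch
from torch.utils.data import TensorDataset

from base import BaseDataLoader


class RandomTensorDataLoader(BaseDataLoader):
    """Hypothetical loader over random tensors; swap in a real Dataset."""
    def __init__(self, batch_size, shuffle=True, validation_split=0.1, num_workers=1):
        x = torch.randn(1000, 1, 28, 28)           # stand-in images
        y = torch.randint(0, 10, (1000,))          # stand-in labels
        super().__init__(TensorDataset(x, y), batch_size, shuffle, validation_split, num_workers)


loader = RandomTensorDataLoader(batch_size=32)     # training batches
valid_loader = loader.split_validation()           # held-out 10% as a separate DataLoader
```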
model/base/base_model.py
ADDED
@@ -0,0 +1,25 @@
import torch.nn as nn
import numpy as np
from abc import abstractmethod


class BaseModel(nn.Module):
    """
    Base class for all models
    """
    @abstractmethod
    def forward(self, *inputs):
        """
        Forward pass logic

        :return: Model output
        """
        raise NotImplementedError

    def __str__(self):
        """
        Model prints with number of trainable parameters
        """
        model_parameters = filter(lambda p: p.requires_grad, self.parameters())
        params = sum([np.prod(p.size()) for p in model_parameters])
        return super().__str__() + '\nTrainable parameters: {}'.format(params)
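A model only needs to subclass `BaseModel` and implement `forward()`; the inherited `__str__` then reports the trainable parameter count. A minimal, hypothetical sketch (not part of this commit):

```python
import torch.nn as nn
import torch.nn.functional as F

from base import BaseModel


class TinyMLP(BaseModel):
    """Hypothetical two-layer MLP for flattened 28x28 inputs."""
    def __init__(self, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 64)
        self.fc2 = nn.Linear(64, num_classes)

    def forward(self, x):
        x = x.view(x.size(0), -1)                  # flatten to (batch, 784)
        return F.log_softmax(self.fc2(F.relu(self.fc1(x))), dim=1)


print(TinyMLP())   # summary from BaseModel includes 'Trainable parameters: ...'
```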
model/base/base_trainer.py
ADDED
@@ -0,0 +1,151 @@
import torch
from abc import abstractmethod
from numpy import inf
from logger import TensorboardWriter


class BaseTrainer:
    """
    Base class for all trainers
    """
    def __init__(self, model, criterion, metric_ftns, optimizer, config):
        self.config = config
        self.logger = config.get_logger('trainer', config['trainer']['verbosity'])

        self.model = model
        self.criterion = criterion
        self.metric_ftns = metric_ftns
        self.optimizer = optimizer

        cfg_trainer = config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.save_period = cfg_trainer['save_period']
        self.monitor = cfg_trainer.get('monitor', 'off')

        # configuration to monitor model performance and save best
        if self.monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ['min', 'max']

            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
            self.early_stop = cfg_trainer.get('early_stop', inf)
            if self.early_stop <= 0:
                self.early_stop = inf

        self.start_epoch = 1

        self.checkpoint_dir = config.save_dir

        # setup visualization writer instance
        self.writer = TensorboardWriter(config.log_dir, self.logger, cfg_trainer['tensorboard'])

        if config.resume is not None:
            self._resume_checkpoint(config.resume)

    @abstractmethod
    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Current epoch number
        """
        raise NotImplementedError

    def train(self):
        """
        Full training logic
        """
        not_improved_count = 0
        for epoch in range(self.start_epoch, self.epochs + 1):
            result = self._train_epoch(epoch)

            # save logged informations into log dict
            log = {'epoch': epoch}
            log.update(result)

            # print logged informations to the screen
            for key, value in log.items():
                self.logger.info('    {:15s}: {}'.format(str(key), value))

            # evaluate model performance according to configured metric, save best checkpoint as model_best
            best = False
            if self.mnt_mode != 'off':
                try:
                    # check whether model performance improved or not, according to specified metric(mnt_metric)
                    improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \
                               (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best)
                except KeyError:
                    self.logger.warning("Warning: Metric '{}' is not found. "
                                        "Model performance monitoring is disabled.".format(self.mnt_metric))
                    self.mnt_mode = 'off'
                    improved = False

                if improved:
                    self.mnt_best = log[self.mnt_metric]
                    not_improved_count = 0
                    best = True
                else:
                    not_improved_count += 1

                if not_improved_count > self.early_stop:
                    self.logger.info("Validation performance didn\'t improve for {} epochs. "
                                     "Training stops.".format(self.early_stop))
                    break

            if epoch % self.save_period == 0:
                self._save_checkpoint(epoch, save_best=best)

    def _save_checkpoint(self, epoch, save_best=False):
        """
        Saving checkpoints

        :param epoch: current epoch number
        :param log: logging information of the epoch
        :param save_best: if True, rename the saved checkpoint to 'model_best.pth'
        """
        arch = type(self.model).__name__
        state = {
            'arch': arch,
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'monitor_best': self.mnt_best,
            'config': self.config
        }
        filename = str(self.checkpoint_dir / 'checkpoint-epoch{}.pth'.format(epoch))
        torch.save(state, filename)
        self.logger.info("Saving checkpoint: {} ...".format(filename))
        if save_best:
            best_path = str(self.checkpoint_dir / 'model_best.pth')
            torch.save(state, best_path)
            self.logger.info("Saving current best: model_best.pth ...")

    def _resume_checkpoint(self, resume_path):
        """
        Resume from saved checkpoints

        :param resume_path: Checkpoint path to be resumed
        """
        resume_path = str(resume_path)
        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
        checkpoint = torch.load(resume_path)
        self.start_epoch = checkpoint['epoch'] + 1
        self.mnt_best = checkpoint['monitor_best']

        # load architecture params from checkpoint.
        if checkpoint['config']['arch'] != self.config['arch']:
            self.logger.warning("Warning: Architecture configuration given in config file is different from that of "
                                "checkpoint. This may yield an exception while state_dict is being loaded.")
        self.model.load_state_dict(checkpoint['state_dict'])

        # load optimizer state from checkpoint only when optimizer type is not changed.
        if checkpoint['config']['optimizer']['type'] != self.config['optimizer']['type']:
            self.logger.warning("Warning: Optimizer type given in config file is different from that of checkpoint. "
                                "Optimizer parameters not being resumed.")
        else:
            self.optimizer.load_state_dict(checkpoint['optimizer'])

        self.logger.info("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch))
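`BaseTrainer.train()` only relies on `_train_epoch(epoch)` returning a dict of values; the keys of that dict are what the `"monitor"` entry in `config.json` (e.g. `min val_loss`) can refer to. The sketch below is a hypothetical, stripped-down subclass, not this commit's `Trainer`, kept only to show that contract.

```python
from base import BaseTrainer


class MinimalTrainer(BaseTrainer):
    """Hypothetical trainer showing only the _train_epoch contract."""
    def __init__(self, model, criterion, metric_ftns, optimizer, config, device, data_loader):
        super().__init__(model, criterion, metric_ftns, optimizer, config)
        self.device = device
        self.data_loader = data_loader

    def _train_epoch(self, epoch):
        self.model.train()
        running_loss = 0.0
        for data, target in self.data_loader:
            data, target = data.to(self.device), target.to(self.device)
            self.optimizer.zero_grad()
            loss = self.criterion(self.model(data), target)
            loss.backward()
            self.optimizer.step()
            running_loss += loss.item()
        # keys returned here become the entries BaseTrainer logs and monitors
        return {'loss': running_loss / len(self.data_loader)}
```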
model/config.json
ADDED
@@ -0,0 +1,50 @@
{
    "name": "Mnist_LeNet",
    "n_gpu": 1,

    "arch": {
        "type": "MnistModel",
        "args": {}
    },
    "data_loader": {
        "type": "MnistDataLoader",
        "args":{
            "data_dir": "data/",
            "batch_size": 128,
            "shuffle": true,
            "validation_split": 0.1,
            "num_workers": 2
        }
    },
    "optimizer": {
        "type": "Adam",
        "args":{
            "lr": 0.001,
            "weight_decay": 0,
            "amsgrad": true
        }
    },
    "loss": "nll_loss",
    "metrics": [
        "accuracy", "top_k_acc"
    ],
    "lr_scheduler": {
        "type": "StepLR",
        "args": {
            "step_size": 50,
            "gamma": 0.1
        }
    },
    "trainer": {
        "epochs": 100,

        "save_dir": "saved/",
        "save_period": 1,
        "verbosity": 2,

        "monitor": "min val_loss",
        "early_stop": 10,

        "tensorboard": true
    }
}
model/data_loader/data_loaders.py
ADDED
@@ -0,0 +1,16 @@
from torchvision import datasets, transforms
from base import BaseDataLoader


class MnistDataLoader(BaseDataLoader):
    """
    MNIST data loading demo using BaseDataLoader
    """
    def __init__(self, data_dir, batch_size, shuffle=True, validation_split=0.0, num_workers=1, training=True):
        trsfm = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
        self.data_dir = data_dir
        self.dataset = datasets.MNIST(self.data_dir, train=training, download=True, transform=trsfm)
        super().__init__(self.dataset, batch_size, shuffle, validation_split, num_workers)
model/logger/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .logger import *
from .visualization import *
model/logger/logger.py
ADDED
@@ -0,0 +1,22 @@
import logging
import logging.config
from pathlib import Path
from utils import read_json


def setup_logging(save_dir, log_config='logger/logger_config.json', default_level=logging.INFO):
    """
    Setup logging configuration
    """
    log_config = Path(log_config)
    if log_config.is_file():
        config = read_json(log_config)
        # modify logging paths based on run config
        for _, handler in config['handlers'].items():
            if 'filename' in handler:
                handler['filename'] = str(save_dir / handler['filename'])

        logging.config.dictConfig(config)
    else:
        print("Warning: logging configuration file is not found in {}.".format(log_config))
        logging.basicConfig(level=default_level)
model/logger/logger_config.json
ADDED
@@ -0,0 +1,32 @@
{
    "version": 1,
    "disable_existing_loggers": false,
    "formatters": {
        "simple": {"format": "%(message)s"},
        "datetime": {"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"}
    },
    "handlers": {
        "console": {
            "class": "logging.StreamHandler",
            "level": "DEBUG",
            "formatter": "simple",
            "stream": "ext://sys.stdout"
        },
        "info_file_handler": {
            "class": "logging.handlers.RotatingFileHandler",
            "level": "INFO",
            "formatter": "datetime",
            "filename": "info.log",
            "maxBytes": 10485760,
            "backupCount": 20, "encoding": "utf8"
        }
    },
    "root": {
        "level": "INFO",
        "handlers": [
            "console",
            "info_file_handler"
        ]
    }
}
model/logger/visualization.py
ADDED
@@ -0,0 +1,73 @@
import importlib
from datetime import datetime


class TensorboardWriter():
    def __init__(self, log_dir, logger, enabled):
        self.writer = None
        self.selected_module = ""

        if enabled:
            log_dir = str(log_dir)

            # Retrieve visualization writer.
            succeeded = False
            for module in ["torch.utils.tensorboard", "tensorboardX"]:
                try:
                    self.writer = importlib.import_module(module).SummaryWriter(log_dir)
                    succeeded = True
                    break
                except ImportError:
                    succeeded = False
                self.selected_module = module

            if not succeeded:
                message = "Warning: visualization (Tensorboard) is configured to use, but currently not installed on " \
                    "this machine. Please install TensorboardX with 'pip install tensorboardx', upgrade PyTorch to " \
                    "version >= 1.1 to use 'torch.utils.tensorboard' or turn off the option in the 'config.json' file."
                logger.warning(message)

        self.step = 0
        self.mode = ''

        self.tb_writer_ftns = {
            'add_scalar', 'add_scalars', 'add_image', 'add_images', 'add_audio',
            'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding'
        }
        self.tag_mode_exceptions = {'add_histogram', 'add_embedding'}
        self.timer = datetime.now()

    def set_step(self, step, mode='train'):
        self.mode = mode
        self.step = step
        if step == 0:
            self.timer = datetime.now()
        else:
            duration = datetime.now() - self.timer
            self.add_scalar('steps_per_sec', 1 / duration.total_seconds())
            self.timer = datetime.now()

    def __getattr__(self, name):
        """
        If visualization is configured to use:
            return add_data() methods of tensorboard with additional information (step, tag) added.
        Otherwise:
            return a blank function handle that does nothing
        """
        if name in self.tb_writer_ftns:
            add_data = getattr(self.writer, name, None)

            def wrapper(tag, data, *args, **kwargs):
                if add_data is not None:
                    # add mode(train/valid) tag
                    if name not in self.tag_mode_exceptions:
                        tag = '{}/{}'.format(tag, self.mode)
                    add_data(tag, data, self.step, *args, **kwargs)
            return wrapper
        else:
            # default action for returning methods defined in this class, set_step() for instance.
            try:
                attr = object.__getattr__(name)
            except AttributeError:
                raise AttributeError("type object '{}' has no attribute '{}'".format(self.selected_module, name))
            return attr
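A short usage sketch of `TensorboardWriter` with illustrative values: `set_step()` fixes the global step and mode, and later `add_*` calls are forwarded to the underlying `SummaryWriter` at that step, with a `/train` or `/valid` suffix appended to the tag (except for histograms and embeddings).

```python
import logging

import torch
from logger import TensorboardWriter

logger = logging.getLogger('demo')                         # any logging.Logger works here
writer = TensorboardWriter('saved/log/demo', logger, enabled=True)

writer.set_step(10, mode='train')                          # later add_* calls log at step 10
writer.add_scalar('loss', 0.42)                            # forwarded as 'loss/train' at step 10
writer.add_histogram('weights', torch.randn(100))          # histograms skip the mode suffix
```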
model/model/loss.py
ADDED
@@ -0,0 +1,5 @@
import torch.nn.functional as F


def nll_loss(output, target):
    return F.nll_loss(output, target)
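Adding another loss is a matter of defining a function with the same `(output, target)` signature in this file; `train.py` resolves it by name via `getattr(module_loss, config['loss'])`, so pointing `"loss"` in `config.json` at the new name is enough. A hypothetical example, not part of this commit:

```python
import torch.nn.functional as F


def cross_entropy_loss(output, target):
    # expects raw logits, unlike nll_loss which expects log-probabilities
    return F.cross_entropy(output, target)
```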
model/model/metric.py
ADDED
@@ -0,0 +1,20 @@
import torch


def accuracy(output, target):
    with torch.no_grad():
        pred = torch.argmax(output, dim=1)
        assert pred.shape[0] == len(target)
        correct = 0
        correct += torch.sum(pred == target).item()
    return correct / len(target)


def top_k_acc(output, target, k=3):
    with torch.no_grad():
        pred = torch.topk(output, k, dim=1)[1]
        assert pred.shape[0] == len(target)
        correct = 0
        for i in range(k):
            correct += torch.sum(pred[:, i] == target).item()
    return correct / len(target)
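Custom metrics follow the same pattern: a plain function taking `(output, target)` and returning a scalar, listed by name under `"metrics"` in `config.json`. A hypothetical addition, not part of this commit:

```python
import torch


def error_rate(output, target):
    with torch.no_grad():
        pred = torch.argmax(output, dim=1)
        wrong = torch.sum(pred != target).item()
    return wrong / len(target)
```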
model/model/model.py
ADDED
@@ -0,0 +1,22 @@
import torch.nn as nn
import torch.nn.functional as F
from base import BaseModel


class MnistModel(BaseModel):
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, num_classes)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
model/new_project.py
ADDED
@@ -0,0 +1,18 @@
import sys
from pathlib import Path
from shutil import copytree, ignore_patterns


# This script initializes a new pytorch project with the template files.
# Run `python3 new_project.py ../MyNewProject` and a new project named
# MyNewProject will be made.
current_dir = Path()
assert (current_dir / 'new_project.py').is_file(), 'Script should be executed in the pytorch-template directory'
assert len(sys.argv) == 2, 'Specify a name for the new project. Example: python3 new_project.py MyNewProject'

project_name = Path(sys.argv[1])
target_dir = current_dir / project_name

ignore = [".git", "data", "saved", "new_project.py", "LICENSE", ".flake8", "README.md", "__pycache__"]
copytree(current_dir, target_dir, ignore=ignore_patterns(*ignore))
print('New project initialized at', target_dir.absolute().resolve())
model/parse_config.py
ADDED
@@ -0,0 +1,157 @@
import os
import logging
from pathlib import Path
from functools import reduce, partial
from operator import getitem
from datetime import datetime
from logger import setup_logging
from utils import read_json, write_json


class ConfigParser:
    def __init__(self, config, resume=None, modification=None, run_id=None):
        """
        class to parse configuration json file. Handles hyperparameters for training, initializations of modules, checkpoint saving
        and logging module.
        :param config: Dict containing configurations, hyperparameters for training. contents of `config.json` file for example.
        :param resume: String, path to the checkpoint being loaded.
        :param modification: Dict keychain:value, specifying position values to be replaced from config dict.
        :param run_id: Unique Identifier for training processes. Used to save checkpoints and training log. Timestamp is being used as default
        """
        # load config file and apply modification
        self._config = _update_config(config, modification)
        self.resume = resume

        # set save_dir where trained model and log will be saved.
        save_dir = Path(self.config['trainer']['save_dir'])

        exper_name = self.config['name']
        if run_id is None: # use timestamp as default run-id
            run_id = datetime.now().strftime(r'%m%d_%H%M%S')
        self._save_dir = save_dir / 'models' / exper_name / run_id
        self._log_dir = save_dir / 'log' / exper_name / run_id

        # make directory for saving checkpoints and log.
        exist_ok = run_id == ''
        self.save_dir.mkdir(parents=True, exist_ok=exist_ok)
        self.log_dir.mkdir(parents=True, exist_ok=exist_ok)

        # save updated config file to the checkpoint dir
        write_json(self.config, self.save_dir / 'config.json')

        # configure logging module
        setup_logging(self.log_dir)
        self.log_levels = {
            0: logging.WARNING,
            1: logging.INFO,
            2: logging.DEBUG
        }

    @classmethod
    def from_args(cls, args, options=''):
        """
        Initialize this class from some cli arguments. Used in train, test.
        """
        for opt in options:
            args.add_argument(*opt.flags, default=None, type=opt.type)
        if not isinstance(args, tuple):
            args = args.parse_args()

        if args.device is not None:
            os.environ["CUDA_VISIBLE_DEVICES"] = args.device
        if args.resume is not None:
            resume = Path(args.resume)
            cfg_fname = resume.parent / 'config.json'
        else:
            msg_no_cfg = "Configuration file need to be specified. Add '-c config.json', for example."
            assert args.config is not None, msg_no_cfg
            resume = None
            cfg_fname = Path(args.config)

        config = read_json(cfg_fname)
        if args.config and resume:
            # update new config for fine-tuning
            config.update(read_json(args.config))

        # parse custom cli options into dictionary
        modification = {opt.target : getattr(args, _get_opt_name(opt.flags)) for opt in options}
        return cls(config, resume, modification)

    def init_obj(self, name, module, *args, **kwargs):
        """
        Finds a function handle with the name given as 'type' in config, and returns the
        instance initialized with corresponding arguments given.

        `object = config.init_obj('name', module, a, b=1)`
        is equivalent to
        `object = module.name(a, b=1)`
        """
        module_name = self[name]['type']
        module_args = dict(self[name]['args'])
        assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed'
        module_args.update(kwargs)
        return getattr(module, module_name)(*args, **module_args)

    def init_ftn(self, name, module, *args, **kwargs):
        """
        Finds a function handle with the name given as 'type' in config, and returns the
        function with given arguments fixed with functools.partial.

        `function = config.init_ftn('name', module, a, b=1)`
        is equivalent to
        `function = lambda *args, **kwargs: module.name(a, *args, b=1, **kwargs)`.
        """
        module_name = self[name]['type']
        module_args = dict(self[name]['args'])
        assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed'
        module_args.update(kwargs)
        return partial(getattr(module, module_name), *args, **module_args)

    def __getitem__(self, name):
        """Access items like ordinary dict."""
        return self.config[name]

    def get_logger(self, name, verbosity=2):
        msg_verbosity = 'verbosity option {} is invalid. Valid options are {}.'.format(verbosity, self.log_levels.keys())
        assert verbosity in self.log_levels, msg_verbosity
        logger = logging.getLogger(name)
        logger.setLevel(self.log_levels[verbosity])
        return logger

    # setting read-only attributes
    @property
    def config(self):
        return self._config

    @property
    def save_dir(self):
        return self._save_dir

    @property
    def log_dir(self):
        return self._log_dir

# helper functions to update config dict with custom cli options
def _update_config(config, modification):
    if modification is None:
        return config

    for k, v in modification.items():
        if v is not None:
            _set_by_path(config, k, v)
    return config

def _get_opt_name(flags):
    for flg in flags:
        if flg.startswith('--'):
            return flg.replace('--', '')
    return flags[0].replace('--', '')

def _set_by_path(tree, keys, value):
    """Set a value in a nested object in tree by sequence of keys."""
    keys = keys.split(';')
    _get_by_path(tree, keys[:-1])[keys[-1]] = value

def _get_by_path(tree, keys):
    """Access a nested object in tree by sequence of keys."""
    return reduce(getitem, keys, tree)
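A hedged sketch of using `ConfigParser` outside of `train.py`/`test.py`; it assumes running from the project root with the `config.json` above (constructing it creates the `saved/` directories), and the parameter list is only a stand-in for real model parameters.

```python
import torch

from parse_config import ConfigParser
from utils import read_json

config = ConfigParser(read_json('config.json'))            # creates saved/models/... and saved/log/...
params = [torch.nn.Parameter(torch.zeros(1))]              # stand-in for model.parameters()

optimizer = config.init_obj('optimizer', torch.optim, params)
# per the docstring, equivalent to:
#   torch.optim.Adam(params, lr=0.001, weight_decay=0, amsgrad=True)

logger = config.get_logger('demo', verbosity=2)            # 0 -> WARNING, 1 -> INFO, 2 -> DEBUG
logger.debug('built optimizer %s', type(optimizer).__name__)
```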
model/requirements.txt
ADDED
@@ -0,0 +1,5 @@
torch>=1.1
torchvision
numpy
tqdm
tensorboard>=1.14
model/test.py
ADDED
@@ -0,0 +1,81 @@
import argparse
import torch
from tqdm import tqdm
import data_loader.data_loaders as module_data
import model.loss as module_loss
import model.metric as module_metric
import model.model as module_arch
from parse_config import ConfigParser


def main(config):
    logger = config.get_logger('test')

    # setup data_loader instances
    data_loader = getattr(module_data, config['data_loader']['type'])(
        config['data_loader']['args']['data_dir'],
        batch_size=512,
        shuffle=False,
        validation_split=0.0,
        training=False,
        num_workers=2
    )

    # build model architecture
    model = config.init_obj('arch', module_arch)
    logger.info(model)

    # get function handles of loss and metrics
    loss_fn = getattr(module_loss, config['loss'])
    metric_fns = [getattr(module_metric, met) for met in config['metrics']]

    logger.info('Loading checkpoint: {} ...'.format(config.resume))
    checkpoint = torch.load(config.resume)
    state_dict = checkpoint['state_dict']
    if config['n_gpu'] > 1:
        model = torch.nn.DataParallel(model)
    model.load_state_dict(state_dict)

    # prepare model for testing
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()

    total_loss = 0.0
    total_metrics = torch.zeros(len(metric_fns))

    with torch.no_grad():
        for i, (data, target) in enumerate(tqdm(data_loader)):
            data, target = data.to(device), target.to(device)
            output = model(data)

            #
            # save sample images, or do something with output here
            #

            # computing loss, metrics on test set
            loss = loss_fn(output, target)
            batch_size = data.shape[0]
            total_loss += loss.item() * batch_size
            for i, metric in enumerate(metric_fns):
                total_metrics[i] += metric(output, target) * batch_size

    n_samples = len(data_loader.sampler)
    log = {'loss': total_loss / n_samples}
    log.update({
        met.__name__: total_metrics[i].item() / n_samples for i, met in enumerate(metric_fns)
    })
    logger.info(log)


if __name__ == '__main__':
    args = argparse.ArgumentParser(description='PyTorch Template')
    args.add_argument('-c', '--config', default=None, type=str,
                      help='config file path (default: None)')
    args.add_argument('-r', '--resume', default=None, type=str,
                      help='path to latest checkpoint (default: None)')
    args.add_argument('-d', '--device', default=None, type=str,
                      help='indices of GPUs to enable (default: all)')

    config = ConfigParser.from_args(args)
    main(config)
model/train.py
ADDED
@@ -0,0 +1,73 @@
import argparse
import collections
import torch
import numpy as np
import data_loader.data_loaders as module_data
import model.loss as module_loss
import model.metric as module_metric
import model.model as module_arch
from parse_config import ConfigParser
from trainer import Trainer
from utils import prepare_device


# fix random seeds for reproducibility
SEED = 123
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(SEED)

def main(config):
    logger = config.get_logger('train')

    # setup data_loader instances
    data_loader = config.init_obj('data_loader', module_data)
    valid_data_loader = data_loader.split_validation()

    # build model architecture, then print to console
    model = config.init_obj('arch', module_arch)
    logger.info(model)

    # prepare for (multi-device) GPU training
    device, device_ids = prepare_device(config['n_gpu'])
    model = model.to(device)
    if len(device_ids) > 1:
        model = torch.nn.DataParallel(model, device_ids=device_ids)

    # get function handles of loss and metrics
    criterion = getattr(module_loss, config['loss'])
    metrics = [getattr(module_metric, met) for met in config['metrics']]

    # build optimizer, learning rate scheduler. delete every line containing lr_scheduler to disable the scheduler
    trainable_params = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = config.init_obj('optimizer', torch.optim, trainable_params)
    lr_scheduler = config.init_obj('lr_scheduler', torch.optim.lr_scheduler, optimizer)

    trainer = Trainer(model, criterion, metrics, optimizer,
                      config=config,
                      device=device,
                      data_loader=data_loader,
                      valid_data_loader=valid_data_loader,
                      lr_scheduler=lr_scheduler)

    trainer.train()


if __name__ == '__main__':
    args = argparse.ArgumentParser(description='PyTorch Template')
    args.add_argument('-c', '--config', default=None, type=str,
                      help='config file path (default: None)')
    args.add_argument('-r', '--resume', default=None, type=str,
                      help='path to latest checkpoint (default: None)')
    args.add_argument('-d', '--device', default=None, type=str,
                      help='indices of GPUs to enable (default: all)')

    # custom cli options to modify configuration from default values given in json file.
    CustomArgs = collections.namedtuple('CustomArgs', 'flags type target')
    options = [
        CustomArgs(['--lr', '--learning_rate'], type=float, target='optimizer;args;lr'),
        CustomArgs(['--bs', '--batch_size'], type=int, target='data_loader;args;batch_size')
    ]
    config = ConfigParser.from_args(args, options)
    main(config)
model/trainer/__init__.py
ADDED
@@ -0,0 +1 @@
from .trainer import *
model/trainer/trainer.py
ADDED
@@ -0,0 +1,110 @@
import numpy as np
import torch
from torchvision.utils import make_grid
from base import BaseTrainer
from utils import inf_loop, MetricTracker


class Trainer(BaseTrainer):
    """
    Trainer class
    """
    def __init__(self, model, criterion, metric_ftns, optimizer, config, device,
                 data_loader, valid_data_loader=None, lr_scheduler=None, len_epoch=None):
        super().__init__(model, criterion, metric_ftns, optimizer, config)
        self.config = config
        self.device = device
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))

        self.train_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)
        self.valid_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        self.model.train()
        self.train_metrics.reset()
        for batch_idx, (data, target) in enumerate(self.data_loader):
            data, target = data.to(self.device), target.to(self.device)

            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()

            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            self.train_metrics.update('loss', loss.item())
            for met in self.metric_ftns:
                self.train_metrics.update(met.__name__, met(output, target))

            if batch_idx % self.log_step == 0:
                self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
                    epoch,
                    self._progress(batch_idx),
                    loss.item()))
                self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))

            if batch_idx == self.len_epoch:
                break
        log = self.train_metrics.result()

        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_'+k : v for k, v in val_log.items()})

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.valid_metrics.reset()
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(self.valid_data_loader):
                data, target = data.to(self.device), target.to(self.device)

                output = self.model(data)
                loss = self.criterion(output, target)

                self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid')
                self.valid_metrics.update('loss', loss.item())
                for met in self.metric_ftns:
                    self.valid_metrics.update(met.__name__, met(output, target))
                self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))

        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
        return self.valid_metrics.result()

    def _progress(self, batch_idx):
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)
model/utils/__init__.py
ADDED
@@ -0,0 +1 @@
from .util import *
model/utils/util.py
ADDED
@@ -0,0 +1,67 @@
import json
import torch
import pandas as pd
from pathlib import Path
from itertools import repeat
from collections import OrderedDict


def ensure_dir(dirname):
    dirname = Path(dirname)
    if not dirname.is_dir():
        dirname.mkdir(parents=True, exist_ok=False)

def read_json(fname):
    fname = Path(fname)
    with fname.open('rt') as handle:
        return json.load(handle, object_hook=OrderedDict)

def write_json(content, fname):
    fname = Path(fname)
    with fname.open('wt') as handle:
        json.dump(content, handle, indent=4, sort_keys=False)

def inf_loop(data_loader):
    ''' wrapper function for endless data loader. '''
    for loader in repeat(data_loader):
        yield from loader

def prepare_device(n_gpu_use):
    """
    setup GPU device if available. get gpu device indices which are used for DataParallel
    """
    n_gpu = torch.cuda.device_count()
    if n_gpu_use > 0 and n_gpu == 0:
        print("Warning: There\'s no GPU available on this machine,"
              "training will be performed on CPU.")
        n_gpu_use = 0
    if n_gpu_use > n_gpu:
        print(f"Warning: The number of GPU\'s configured to use is {n_gpu_use}, but only {n_gpu} are "
              "available on this machine.")
        n_gpu_use = n_gpu
    device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
    list_ids = list(range(n_gpu_use))
    return device, list_ids

class MetricTracker:
    def __init__(self, *keys, writer=None):
        self.writer = writer
        self._data = pd.DataFrame(index=keys, columns=['total', 'counts', 'average'])
        self.reset()

    def reset(self):
        for col in self._data.columns:
            self._data[col].values[:] = 0

    def update(self, key, value, n=1):
        if self.writer is not None:
            self.writer.add_scalar(key, value)
        self._data.total[key] += value * n
        self._data.counts[key] += n
        self._data.average[key] = self._data.total[key] / self._data.counts[key]

    def avg(self, key):
        return self._data.average[key]

    def result(self):
        return dict(self._data.average)