Hannes Kuchelmeister committed on
Commit b72a776
1 Parent(s): 3d5c288

add model to repository

model/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+ 
+ Copyright (c) 2018 Victor Huang
+ 
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+ 
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+ 
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
model/README.md ADDED
@@ -0,0 +1,378 @@
+ # PyTorch Template Project
+ PyTorch deep learning project made easy.
+ 
+ <!-- @import "[TOC]" {cmd="toc" depthFrom=1 depthTo=6 orderedList=false} -->
+ 
+ <!-- code_chunk_output -->
+ 
+ * [PyTorch Template Project](#pytorch-template-project)
+   * [Requirements](#requirements)
+   * [Features](#features)
+   * [Folder Structure](#folder-structure)
+   * [Usage](#usage)
+     * [Config file format](#config-file-format)
+     * [Using config files](#using-config-files)
+     * [Resuming from checkpoints](#resuming-from-checkpoints)
+     * [Using Multiple GPU](#using-multiple-gpu)
+   * [Customization](#customization)
+     * [Custom CLI options](#custom-cli-options)
+     * [Data Loader](#data-loader)
+     * [Trainer](#trainer)
+     * [Model](#model)
+     * [Loss](#loss)
+     * [Metrics](#metrics)
+     * [Additional logging](#additional-logging)
+     * [Validation data](#validation-data)
+     * [Checkpoints](#checkpoints)
+     * [Tensorboard Visualization](#tensorboard-visualization)
+   * [Contribution](#contribution)
+   * [TODOs](#todos)
+   * [License](#license)
+   * [Acknowledgements](#acknowledgements)
+ 
+ <!-- /code_chunk_output -->
+ 
+ ## Requirements
+ * Python >= 3.5 (3.6 recommended)
+ * PyTorch >= 0.4 (1.2 recommended)
+ * tqdm (Optional for `test.py`)
+ * tensorboard >= 1.14 (see [Tensorboard Visualization](#tensorboard-visualization))
+ 
+ ## Features
+ * Clear folder structure which is suitable for many deep learning projects.
+ * `.json` config file support for convenient parameter tuning.
+ * Customizable command line options for more convenient parameter tuning.
+ * Checkpoint saving and resuming.
+ * Abstract base classes for faster development:
+   * `BaseTrainer` handles checkpoint saving/resuming, training process logging, and more.
+   * `BaseDataLoader` handles batch generation, data shuffling, and validation data splitting.
+   * `BaseModel` provides basic model summary.
+ 
+ ## Folder Structure
+ ```
+ pytorch-template/
+ 
+ ├── train.py - main script to start training
+ ├── test.py - evaluation of trained model
+ 
+ ├── config.json - holds configuration for training
+ ├── parse_config.py - class to handle config file and cli options
+ 
+ ├── new_project.py - initialize new project with template files
+ 
+ ├── base/ - abstract base classes
+ │   ├── base_data_loader.py
+ │   ├── base_model.py
+ │   └── base_trainer.py
+ 
+ ├── data_loader/ - anything about data loading goes here
+ │   └── data_loaders.py
+ 
+ ├── data/ - default directory for storing input data
+ 
+ ├── model/ - models, losses, and metrics
+ │   ├── model.py
+ │   ├── metric.py
+ │   └── loss.py
+ 
+ ├── saved/
+ │   ├── models/ - trained models are saved here
+ │   └── log/ - default logdir for tensorboard and logging output
+ 
+ ├── trainer/ - trainers
+ │   └── trainer.py
+ 
+ ├── logger/ - module for tensorboard visualization and logging
+ │   ├── visualization.py
+ │   ├── logger.py
+ │   └── logger_config.json
+ 
+ └── utils/ - small utility functions
+     ├── util.py
+     └── ...
+ ```
+ 
+ ## Usage
+ The code in this repo is an MNIST example of the template.
+ Run `python train.py -c config.json` to train it.
+ 
+ ### Config file format
+ Config files are in `.json` format:
+ ```javascript
+ {
+   "name": "Mnist_LeNet",        // training session name
+   "n_gpu": 1,                   // number of GPUs to use for training
+ 
+   "arch": {
+     "type": "MnistModel",       // name of model architecture to train
+     "args": {
+ 
+     }
+   },
+   "data_loader": {
+     "type": "MnistDataLoader",  // selecting data loader
+     "args":{
+       "data_dir": "data/",      // dataset path
+       "batch_size": 64,         // batch size
+       "shuffle": true,          // shuffle training data before splitting
+       "validation_split": 0.1,  // size of validation dataset. float(portion) or int(number of samples)
+       "num_workers": 2          // number of cpu processes to be used for data loading
+     }
+   },
+   "optimizer": {
+     "type": "Adam",
+     "args":{
+       "lr": 0.001,              // learning rate
+       "weight_decay": 0,        // (optional) weight decay
+       "amsgrad": true
+     }
+   },
+   "loss": "nll_loss",           // loss
+   "metrics": [
+     "accuracy", "top_k_acc"     // list of metrics to evaluate
+   ],
+   "lr_scheduler": {
+     "type": "StepLR",           // learning rate scheduler
+     "args":{
+       "step_size": 50,
+       "gamma": 0.1
+     }
+   },
+   "trainer": {
+     "epochs": 100,              // number of training epochs
+     "save_dir": "saved/",       // checkpoints are saved in save_dir/models/name
+     "save_period": 1,           // save checkpoints every save_period epochs
+     "verbosity": 2,             // 0: quiet, 1: per epoch, 2: full
+ 
+     "monitor": "min val_loss",  // mode and metric for model performance monitoring. set 'off' to disable.
+     "early_stop": 10,           // number of epochs to wait before early stop. set 0 to disable.
+ 
+     "tensorboard": true         // enable tensorboard visualization
+   }
+ }
+ ```
+ 
+ Add additional configurations as needed.
+ 
+ ### Using config files
+ Modify the configurations in `.json` config files, then run:
+ 
+ ```
+ python train.py --config config.json
+ ```
+ 
+ ### Resuming from checkpoints
+ You can resume from a previously saved checkpoint by:
+ 
+ ```
+ python train.py --resume path/to/checkpoint
+ ```
+ 
+ ### Using Multiple GPU
+ You can enable multi-GPU training by setting the `n_gpu` argument of the config file to a larger number.
+ If it is configured to use fewer GPUs than are available, the first n devices will be used by default.
+ Specify the indices of the GPUs to use with the CUDA environment variable, or with the `--device` option:
+ ```
+ python train.py --device 2,3 -c config.json
+ ```
+ This is equivalent to
+ ```
+ CUDA_VISIBLE_DEVICES=2,3 python train.py -c config.json
+ ```
+ 
+ ## Customization
+ 
+ ### Project initialization
+ Use the `new_project.py` script to create a new project directory with the template files.
+ Running `python new_project.py ../NewProject` makes a new project folder named 'NewProject'.
+ This script filters out unnecessary files such as caches, git files, and the README.
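+ 
+ For example, to initialize a project and then train the bundled MNIST example from it (paths are illustrative):
+ ```
+ python new_project.py ../NewProject
+ cd ../NewProject
+ python train.py -c config.json
+ ```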
+ 
+ ### Custom CLI options
+ 
+ Changing the values in the config file is a clean, safe, and easy way of tuning hyperparameters. However, sometimes
+ it is better to have command line options for values that need to be changed often or quickly.
+ 
+ This template uses the configurations stored in the json file by default, but by registering custom options as follows
+ you can change some of them using CLI flags.
+ 
+ ```python
+ # simple class-like object having 3 attributes, `flags`, `type`, `target`.
+ CustomArgs = collections.namedtuple('CustomArgs', 'flags type target')
+ options = [
+     CustomArgs(['--lr', '--learning_rate'], type=float, target=('optimizer', 'args', 'lr')),
+     CustomArgs(['--bs', '--batch_size'], type=int, target=('data_loader', 'args', 'batch_size'))
+     # options added here can be modified by command line flags.
+ ]
+ ```
+ The `target` argument should be a sequence of keys, which are used to access that option in the config dict. In this example, `target`
+ for the learning rate option is `('optimizer', 'args', 'lr')` because `config['optimizer']['args']['lr']` points to the learning rate.
+ `python train.py -c config.json --bs 256` runs training with the options given in `config.json`, except for the batch size,
+ which is overridden to 256 by the command line option.
+ 
+ ### Data Loader
+ * **Writing your own data loader**
+ 
+ 1. **Inherit ```BaseDataLoader```**
+ 
+     `BaseDataLoader` is a subclass of `torch.utils.data.DataLoader`, so you can use either of them.
+ 
+     `BaseDataLoader` handles:
+     * Generating the next batch
+     * Data shuffling
+     * Generating a validation data loader by calling
+     `BaseDataLoader.split_validation()`
+ 
+ * **DataLoader Usage**
+ 
+   `BaseDataLoader` is an iterator; to iterate through batches:
+   ```python
+   for batch_idx, (x_batch, y_batch) in enumerate(data_loader):
+       pass
+   ```
+ * **Example**
+ 
+   Please refer to `data_loader/data_loaders.py` for an MNIST data loading example; a minimal custom loader sketch is shown below.
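+ 
+ As an illustration only (not part of the template), a custom loader can stay very small; here `MyDataset` and its import path are placeholders for your own `torch.utils.data.Dataset`:
+ ```python
+ from base import BaseDataLoader
+ from my_package.datasets import MyDataset  # hypothetical dataset class
+ 
+ 
+ class MyDataLoader(BaseDataLoader):
+     """Reuses BaseDataLoader's batching, shuffling and validation splitting."""
+     def __init__(self, data_dir, batch_size, shuffle=True, validation_split=0.0, num_workers=1):
+         self.dataset = MyDataset(data_dir)
+         super().__init__(self.dataset, batch_size, shuffle, validation_split, num_workers)
+ ```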
+ 
+ ### Trainer
+ * **Writing your own trainer**
+ 
+ 1. **Inherit ```BaseTrainer```**
+ 
+     `BaseTrainer` handles:
+     * Training process logging
+     * Checkpoint saving
+     * Checkpoint resuming
+     * Reconfigurable performance monitoring for saving the current best model and for early stopping:
+       * If config `monitor` is set to `max val_accuracy`, the trainer will save a checkpoint `model_best.pth` whenever the epoch's `validation accuracy` exceeds the current `maximum`.
+       * If config `early_stop` is set, training will be terminated automatically when model performance does not improve for the given number of epochs. This feature can be turned off by passing 0 to the `early_stop` option, or by deleting that line of the config.
+ 
+ 2. **Implementing abstract methods**
+ 
+     You need to implement `_train_epoch()` for your training process. If you need validation, you can also implement `_valid_epoch()` as in `trainer/trainer.py`. A bare-bones sketch is shown below.
+ 
+ * **Example**
+ 
+   Please refer to `trainer/trainer.py` for MNIST training.
+ 
+ * **Iteration-based training**
+ 
+   `Trainer.__init__` takes an optional argument, `len_epoch`, which controls the number of batches (steps) in each epoch.
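+ 
+ A bare-bones custom trainer, shown only as a sketch of the required pieces (the real `trainer/trainer.py` additionally tracks metrics and Tensorboard steps):
+ ```python
+ from base import BaseTrainer
+ 
+ 
+ class MyTrainer(BaseTrainer):
+     def __init__(self, model, criterion, metric_ftns, optimizer, config, device, data_loader):
+         super().__init__(model, criterion, metric_ftns, optimizer, config)
+         self.device = device
+         self.data_loader = data_loader
+ 
+     def _train_epoch(self, epoch):
+         self.model.train()
+         total_loss = 0.0
+         for data, target in self.data_loader:
+             data, target = data.to(self.device), target.to(self.device)
+             self.optimizer.zero_grad()
+             loss = self.criterion(self.model(data), target)
+             loss.backward()
+             self.optimizer.step()
+             total_loss += loss.item()
+         # the returned dict is merged into the epoch log printed by BaseTrainer.train()
+         return {'loss': total_loss / len(self.data_loader)}
+ ```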
+ 
+ ### Model
+ * **Writing your own model**
+ 
+ 1. **Inherit `BaseModel`**
+ 
+     `BaseModel` handles:
+     * Inheriting from `torch.nn.Module`
+     * `__str__`: modifies the native `print` output so that it also reports the number of trainable parameters.
+ 
+ 2. **Implementing abstract methods**
+ 
+     Implement the forward pass method `forward()`.
+ 
+ * **Example**
+ 
+   Please refer to `model/model.py` for a LeNet example; a smaller sketch is given below.
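+ 
+ As a minimal sketch (not part of the template), a fully-connected classifier for MNIST-sized inputs:
+ ```python
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from base import BaseModel
+ 
+ 
+ class SimpleMnistModel(BaseModel):
+     def __init__(self, num_classes=10):
+         super().__init__()
+         self.fc = nn.Linear(28 * 28, num_classes)
+ 
+     def forward(self, x):
+         # flatten the image and return log-probabilities, matching nll_loss
+         return F.log_softmax(self.fc(x.view(x.size(0), -1)), dim=1)
+ ```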
+ 
+ ### Loss
+ Custom loss functions can be implemented in `model/loss.py`. Use them by setting the `"loss"` entry in the config file to the corresponding function name.
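+ 
+ For example, a hypothetical alternative loss that could be added next to `nll_loss` and selected with `"loss": "cross_entropy_loss"`:
+ ```python
+ import torch.nn.functional as F
+ 
+ 
+ def cross_entropy_loss(output, target):
+     # note: expects raw logits; the bundled MnistModel returns log-probabilities instead
+     return F.cross_entropy(output, target)
+ ```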
+ 
+ ### Metrics
+ Metric functions are located in `model/metric.py`.
+ 
+ You can monitor multiple metrics by providing a list in the configuration file, e.g.:
+ ```json
+ "metrics": ["accuracy", "top_k_acc"],
+ ```
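+ 
+ A custom metric is just another function in `model/metric.py` taking `(output, target)`; for instance, a hypothetical error-rate metric:
+ ```python
+ import torch
+ 
+ 
+ def error_rate(output, target):
+     with torch.no_grad():
+         pred = torch.argmax(output, dim=1)
+         return torch.sum(pred != target).item() / len(target)
+ ```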
+ 
+ ### Additional logging
+ If you have additional information to be logged, merge it with `log` in `_train_epoch()` of your trainer class before returning, as shown below:
+ 
+ ```python
+ additional_log = {"gradient_norm": g, "sensitivity": s}
+ log.update(additional_log)
+ return log
+ ```
+ 
+ ### Testing
+ You can test a trained model by running `test.py` and passing the path to the trained checkpoint with the `--resume` argument.
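+ 
+ For example, reusing the placeholder path from the resuming section:
+ ```
+ python test.py --resume path/to/checkpoint
+ ```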
+ 
+ ### Validation data
+ To split validation data from a data loader, call `BaseDataLoader.split_validation()`; it will return a data loader for validation, with the size specified in your config file.
+ The `validation_split` can be a ratio of the validation set to the total data (0.0 <= float < 1.0) or the number of samples (0 <= int < `n_total_samples`).
+ 
+ **Note**: the `split_validation()` method will modify the original data loader.
+ **Note**: `split_validation()` will return `None` if `"validation_split"` is set to `0`.
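+ 
+ A short usage sketch with the MNIST loader shipped in this template:
+ ```python
+ from data_loader.data_loaders import MnistDataLoader
+ 
+ data_loader = MnistDataLoader('data/', batch_size=64, shuffle=True, validation_split=0.1)
+ valid_data_loader = data_loader.split_validation()  # would be None if validation_split were 0
+ ```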
+ 
+ ### Checkpoints
+ You can specify the name of the training session in config files:
+ ```json
+ "name": "MNIST_LeNet",
+ ```
+ 
+ The checkpoints will be saved in `save_dir/models/name/timestamp/checkpoint-epoch_n.pth`, with the timestamp in mmdd_HHMMSS format.
+ 
+ A copy of the config file will be saved in the same folder.
+ 
+ **Note**: checkpoints contain:
+ ```python
+ {
+   'arch': arch,
+   'epoch': epoch,
+   'state_dict': self.model.state_dict(),
+   'optimizer': self.optimizer.state_dict(),
+   'monitor_best': self.mnt_best,
+   'config': self.config
+ }
+ ```
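+ 
+ A saved checkpoint can also be inspected or reloaded manually from the project root; a sketch only, with an illustrative placeholder path:
+ ```python
+ import torch
+ import model.model as module_arch
+ 
+ checkpoint = torch.load('saved/models/Mnist_LeNet/0226_120000/model_best.pth')
+ model = module_arch.MnistModel()
+ model.load_state_dict(checkpoint['state_dict'])
+ print(checkpoint['epoch'], checkpoint['monitor_best'])
+ ```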
+ 
+ ### Tensorboard Visualization
+ This template supports Tensorboard visualization by using either `torch.utils.tensorboard` or [TensorboardX](https://github.com/lanpa/tensorboardX).
+ 
+ 1. **Install**
+ 
+     If you are using pytorch 1.1 or higher, install tensorboard with `pip install tensorboard>=1.14.0`.
+ 
+     Otherwise, you should install TensorboardX. Follow the installation guide in [TensorboardX](https://github.com/lanpa/tensorboardX).
+ 
+ 2. **Run training**
+ 
+     Make sure that the `tensorboard` option in the config file is turned on.
+ 
+     ```
+     "tensorboard" : true
+     ```
+ 
+ 3. **Open Tensorboard server**
+ 
+     Type `tensorboard --logdir saved/log/` at the project root; the server will then open at `http://localhost:6006`.
+ 
+ By default, the values of the loss and of the metrics specified in the config file, input images, and histograms of model parameters will be logged.
+ If you need more visualizations, use `add_scalar('tag', data)`, `add_image('tag', image)`, etc. in the `trainer._train_epoch` method.
+ The `add_something()` methods in this template are basically wrappers for those of the `tensorboardX.SummaryWriter` and `torch.utils.tensorboard.SummaryWriter` modules.
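+ 
+ For example, inside `_train_epoch()` (a fragment only; `grad_norm` stands for whatever quantity you want to track):
+ ```python
+ self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
+ self.writer.add_scalar('gradient_norm', grad_norm)
+ self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
+ ```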
+ 
+ **Note**: You don't have to specify the current step, since the `TensorboardWriter` class defined in `logger/visualization.py` keeps track of it.
+ 
+ ## Contribution
+ Feel free to contribute any kind of function or enhancement. The coding style follows PEP8.
+ 
+ Code should pass the [Flake8](http://flake8.pycqa.org/en/latest/) check before committing.
+ 
+ ## TODOs
+ 
+ - [ ] Multiple optimizers
+ - [ ] Support more tensorboard functions
+ - [x] Using fixed random seed
+ - [x] Support pytorch native tensorboard
+ - [x] `tensorboardX` logger support
+ - [x] Configurable logging layout, checkpoint naming
+ - [x] Iteration-based training (instead of epoch-based)
+ - [x] Adding command line option for fine-tuning
+ 
+ ## License
+ This project is licensed under the MIT License. See LICENSE for more details.
+ 
+ ## Acknowledgements
+ This project is inspired by the project [Tensorflow-Project-Template](https://github.com/MrGemy95/Tensorflow-Project-Template) by [Mahmoud Gemy](https://github.com/MrGemy95)
model/base/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .base_data_loader import *
+ from .base_model import *
+ from .base_trainer import *
model/base/base_data_loader.py ADDED
@@ -0,0 +1,61 @@
+ import numpy as np
+ from torch.utils.data import DataLoader
+ from torch.utils.data.dataloader import default_collate
+ from torch.utils.data.sampler import SubsetRandomSampler
+ 
+ 
+ class BaseDataLoader(DataLoader):
+     """
+     Base class for all data loaders
+     """
+     def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=default_collate):
+         self.validation_split = validation_split
+         self.shuffle = shuffle
+ 
+         self.batch_idx = 0
+         self.n_samples = len(dataset)
+ 
+         self.sampler, self.valid_sampler = self._split_sampler(self.validation_split)
+ 
+         self.init_kwargs = {
+             'dataset': dataset,
+             'batch_size': batch_size,
+             'shuffle': self.shuffle,
+             'collate_fn': collate_fn,
+             'num_workers': num_workers
+         }
+         super().__init__(sampler=self.sampler, **self.init_kwargs)
+ 
+     def _split_sampler(self, split):
+         if split == 0.0:
+             return None, None
+ 
+         idx_full = np.arange(self.n_samples)
+ 
+         np.random.seed(0)
+         np.random.shuffle(idx_full)
+ 
+         if isinstance(split, int):
+             assert split > 0
+             assert split < self.n_samples, "validation set size is configured to be larger than entire dataset."
+             len_valid = split
+         else:
+             len_valid = int(self.n_samples * split)
+ 
+         valid_idx = idx_full[0:len_valid]
+         train_idx = np.delete(idx_full, np.arange(0, len_valid))
+ 
+         train_sampler = SubsetRandomSampler(train_idx)
+         valid_sampler = SubsetRandomSampler(valid_idx)
+ 
+         # turn off shuffle option which is mutually exclusive with sampler
+         self.shuffle = False
+         self.n_samples = len(train_idx)
+ 
+         return train_sampler, valid_sampler
+ 
+     def split_validation(self):
+         if self.valid_sampler is None:
+             return None
+         else:
+             return DataLoader(sampler=self.valid_sampler, **self.init_kwargs)
model/base/base_model.py ADDED
@@ -0,0 +1,25 @@
+ import torch.nn as nn
+ import numpy as np
+ from abc import abstractmethod
+ 
+ 
+ class BaseModel(nn.Module):
+     """
+     Base class for all models
+     """
+     @abstractmethod
+     def forward(self, *inputs):
+         """
+         Forward pass logic
+ 
+         :return: Model output
+         """
+         raise NotImplementedError
+ 
+     def __str__(self):
+         """
+         Model prints with number of trainable parameters
+         """
+         model_parameters = filter(lambda p: p.requires_grad, self.parameters())
+         params = sum([np.prod(p.size()) for p in model_parameters])
+         return super().__str__() + '\nTrainable parameters: {}'.format(params)
model/base/base_trainer.py ADDED
@@ -0,0 +1,151 @@
1
+ import torch
2
+ from abc import abstractmethod
3
+ from numpy import inf
4
+ from logger import TensorboardWriter
5
+
6
+
7
+ class BaseTrainer:
8
+ """
9
+ Base class for all trainers
10
+ """
11
+ def __init__(self, model, criterion, metric_ftns, optimizer, config):
12
+ self.config = config
13
+ self.logger = config.get_logger('trainer', config['trainer']['verbosity'])
14
+
15
+ self.model = model
16
+ self.criterion = criterion
17
+ self.metric_ftns = metric_ftns
18
+ self.optimizer = optimizer
19
+
20
+ cfg_trainer = config['trainer']
21
+ self.epochs = cfg_trainer['epochs']
22
+ self.save_period = cfg_trainer['save_period']
23
+ self.monitor = cfg_trainer.get('monitor', 'off')
24
+
25
+ # configuration to monitor model performance and save best
26
+ if self.monitor == 'off':
27
+ self.mnt_mode = 'off'
28
+ self.mnt_best = 0
29
+ else:
30
+ self.mnt_mode, self.mnt_metric = self.monitor.split()
31
+ assert self.mnt_mode in ['min', 'max']
32
+
33
+ self.mnt_best = inf if self.mnt_mode == 'min' else -inf
34
+ self.early_stop = cfg_trainer.get('early_stop', inf)
35
+ if self.early_stop <= 0:
36
+ self.early_stop = inf
37
+
38
+ self.start_epoch = 1
39
+
40
+ self.checkpoint_dir = config.save_dir
41
+
42
+ # setup visualization writer instance
43
+ self.writer = TensorboardWriter(config.log_dir, self.logger, cfg_trainer['tensorboard'])
44
+
45
+ if config.resume is not None:
46
+ self._resume_checkpoint(config.resume)
47
+
48
+ @abstractmethod
49
+ def _train_epoch(self, epoch):
50
+ """
51
+ Training logic for an epoch
52
+
53
+ :param epoch: Current epoch number
54
+ """
55
+ raise NotImplementedError
56
+
57
+ def train(self):
58
+ """
59
+ Full training logic
60
+ """
61
+ not_improved_count = 0
62
+ for epoch in range(self.start_epoch, self.epochs + 1):
63
+ result = self._train_epoch(epoch)
64
+
65
+ # save logged informations into log dict
66
+ log = {'epoch': epoch}
67
+ log.update(result)
68
+
69
+ # print logged informations to the screen
70
+ for key, value in log.items():
71
+ self.logger.info(' {:15s}: {}'.format(str(key), value))
72
+
73
+ # evaluate model performance according to configured metric, save best checkpoint as model_best
74
+ best = False
75
+ if self.mnt_mode != 'off':
76
+ try:
77
+ # check whether model performance improved or not, according to specified metric(mnt_metric)
78
+ improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \
79
+ (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best)
80
+ except KeyError:
81
+ self.logger.warning("Warning: Metric '{}' is not found. "
82
+ "Model performance monitoring is disabled.".format(self.mnt_metric))
83
+ self.mnt_mode = 'off'
84
+ improved = False
85
+
86
+ if improved:
87
+ self.mnt_best = log[self.mnt_metric]
88
+ not_improved_count = 0
89
+ best = True
90
+ else:
91
+ not_improved_count += 1
92
+
93
+ if not_improved_count > self.early_stop:
94
+ self.logger.info("Validation performance didn\'t improve for {} epochs. "
95
+ "Training stops.".format(self.early_stop))
96
+ break
97
+
98
+ if epoch % self.save_period == 0:
99
+ self._save_checkpoint(epoch, save_best=best)
100
+
101
+ def _save_checkpoint(self, epoch, save_best=False):
102
+ """
103
+ Saving checkpoints
104
+
105
+ :param epoch: current epoch number
106
+ :param log: logging information of the epoch
107
+ :param save_best: if True, rename the saved checkpoint to 'model_best.pth'
108
+ """
109
+ arch = type(self.model).__name__
110
+ state = {
111
+ 'arch': arch,
112
+ 'epoch': epoch,
113
+ 'state_dict': self.model.state_dict(),
114
+ 'optimizer': self.optimizer.state_dict(),
115
+ 'monitor_best': self.mnt_best,
116
+ 'config': self.config
117
+ }
118
+ filename = str(self.checkpoint_dir / 'checkpoint-epoch{}.pth'.format(epoch))
119
+ torch.save(state, filename)
120
+ self.logger.info("Saving checkpoint: {} ...".format(filename))
121
+ if save_best:
122
+ best_path = str(self.checkpoint_dir / 'model_best.pth')
123
+ torch.save(state, best_path)
124
+ self.logger.info("Saving current best: model_best.pth ...")
125
+
126
+ def _resume_checkpoint(self, resume_path):
127
+ """
128
+ Resume from saved checkpoints
129
+
130
+ :param resume_path: Checkpoint path to be resumed
131
+ """
132
+ resume_path = str(resume_path)
133
+ self.logger.info("Loading checkpoint: {} ...".format(resume_path))
134
+ checkpoint = torch.load(resume_path)
135
+ self.start_epoch = checkpoint['epoch'] + 1
136
+ self.mnt_best = checkpoint['monitor_best']
137
+
138
+ # load architecture params from checkpoint.
139
+ if checkpoint['config']['arch'] != self.config['arch']:
140
+ self.logger.warning("Warning: Architecture configuration given in config file is different from that of "
141
+ "checkpoint. This may yield an exception while state_dict is being loaded.")
142
+ self.model.load_state_dict(checkpoint['state_dict'])
143
+
144
+ # load optimizer state from checkpoint only when optimizer type is not changed.
145
+ if checkpoint['config']['optimizer']['type'] != self.config['optimizer']['type']:
146
+ self.logger.warning("Warning: Optimizer type given in config file is different from that of checkpoint. "
147
+ "Optimizer parameters not being resumed.")
148
+ else:
149
+ self.optimizer.load_state_dict(checkpoint['optimizer'])
150
+
151
+ self.logger.info("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch))
model/config.json ADDED
@@ -0,0 +1,50 @@
+ {
+     "name": "Mnist_LeNet",
+     "n_gpu": 1,
+ 
+     "arch": {
+         "type": "MnistModel",
+         "args": {}
+     },
+     "data_loader": {
+         "type": "MnistDataLoader",
+         "args":{
+             "data_dir": "data/",
+             "batch_size": 128,
+             "shuffle": true,
+             "validation_split": 0.1,
+             "num_workers": 2
+         }
+     },
+     "optimizer": {
+         "type": "Adam",
+         "args":{
+             "lr": 0.001,
+             "weight_decay": 0,
+             "amsgrad": true
+         }
+     },
+     "loss": "nll_loss",
+     "metrics": [
+         "accuracy", "top_k_acc"
+     ],
+     "lr_scheduler": {
+         "type": "StepLR",
+         "args": {
+             "step_size": 50,
+             "gamma": 0.1
+         }
+     },
+     "trainer": {
+         "epochs": 100,
+ 
+         "save_dir": "saved/",
+         "save_period": 1,
+         "verbosity": 2,
+ 
+         "monitor": "min val_loss",
+         "early_stop": 10,
+ 
+         "tensorboard": true
+     }
+ }
model/data_loader/data_loaders.py ADDED
@@ -0,0 +1,16 @@
+ from torchvision import datasets, transforms
+ from base import BaseDataLoader
+ 
+ 
+ class MnistDataLoader(BaseDataLoader):
+     """
+     MNIST data loading demo using BaseDataLoader
+     """
+     def __init__(self, data_dir, batch_size, shuffle=True, validation_split=0.0, num_workers=1, training=True):
+         trsfm = transforms.Compose([
+             transforms.ToTensor(),
+             transforms.Normalize((0.1307,), (0.3081,))
+         ])
+         self.data_dir = data_dir
+         self.dataset = datasets.MNIST(self.data_dir, train=training, download=True, transform=trsfm)
+         super().__init__(self.dataset, batch_size, shuffle, validation_split, num_workers)
model/logger/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .logger import *
+ from .visualization import *
model/logger/logger.py ADDED
@@ -0,0 +1,22 @@
+ import logging
+ import logging.config
+ from pathlib import Path
+ from utils import read_json
+ 
+ 
+ def setup_logging(save_dir, log_config='logger/logger_config.json', default_level=logging.INFO):
+     """
+     Setup logging configuration
+     """
+     log_config = Path(log_config)
+     if log_config.is_file():
+         config = read_json(log_config)
+         # modify logging paths based on run config
+         for _, handler in config['handlers'].items():
+             if 'filename' in handler:
+                 handler['filename'] = str(save_dir / handler['filename'])
+ 
+         logging.config.dictConfig(config)
+     else:
+         print("Warning: logging configuration file is not found in {}.".format(log_config))
+         logging.basicConfig(level=default_level)
model/logger/logger_config.json ADDED
@@ -0,0 +1,32 @@
+ {
+     "version": 1,
+     "disable_existing_loggers": false,
+     "formatters": {
+         "simple": {"format": "%(message)s"},
+         "datetime": {"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"}
+     },
+     "handlers": {
+         "console": {
+             "class": "logging.StreamHandler",
+             "level": "DEBUG",
+             "formatter": "simple",
+             "stream": "ext://sys.stdout"
+         },
+         "info_file_handler": {
+             "class": "logging.handlers.RotatingFileHandler",
+             "level": "INFO",
+             "formatter": "datetime",
+             "filename": "info.log",
+             "maxBytes": 10485760,
+             "backupCount": 20, "encoding": "utf8"
+         }
+     },
+     "root": {
+         "level": "INFO",
+         "handlers": [
+             "console",
+             "info_file_handler"
+         ]
+     }
+ }
model/logger/visualization.py ADDED
@@ -0,0 +1,73 @@
1
+ import importlib
2
+ from datetime import datetime
3
+
4
+
5
+ class TensorboardWriter():
6
+ def __init__(self, log_dir, logger, enabled):
7
+ self.writer = None
8
+ self.selected_module = ""
9
+
10
+ if enabled:
11
+ log_dir = str(log_dir)
12
+
13
+ # Retrieve vizualization writer.
14
+ succeeded = False
15
+ for module in ["torch.utils.tensorboard", "tensorboardX"]:
16
+ try:
17
+ self.writer = importlib.import_module(module).SummaryWriter(log_dir)
18
+ succeeded = True
19
+ break
20
+ except ImportError:
21
+ succeeded = False
22
+ self.selected_module = module
23
+
24
+ if not succeeded:
25
+ message = "Warning: visualization (Tensorboard) is configured to use, but currently not installed on " \
26
+ "this machine. Please install TensorboardX with 'pip install tensorboardx', upgrade PyTorch to " \
27
+ "version >= 1.1 to use 'torch.utils.tensorboard' or turn off the option in the 'config.json' file."
28
+ logger.warning(message)
29
+
30
+ self.step = 0
31
+ self.mode = ''
32
+
33
+ self.tb_writer_ftns = {
34
+ 'add_scalar', 'add_scalars', 'add_image', 'add_images', 'add_audio',
35
+ 'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding'
36
+ }
37
+ self.tag_mode_exceptions = {'add_histogram', 'add_embedding'}
38
+ self.timer = datetime.now()
39
+
40
+ def set_step(self, step, mode='train'):
41
+ self.mode = mode
42
+ self.step = step
43
+ if step == 0:
44
+ self.timer = datetime.now()
45
+ else:
46
+ duration = datetime.now() - self.timer
47
+ self.add_scalar('steps_per_sec', 1 / duration.total_seconds())
48
+ self.timer = datetime.now()
49
+
50
+ def __getattr__(self, name):
51
+ """
52
+ If visualization is configured to use:
53
+ return add_data() methods of tensorboard with additional information (step, tag) added.
54
+ Otherwise:
55
+ return a blank function handle that does nothing
56
+ """
57
+ if name in self.tb_writer_ftns:
58
+ add_data = getattr(self.writer, name, None)
59
+
60
+ def wrapper(tag, data, *args, **kwargs):
61
+ if add_data is not None:
62
+ # add mode(train/valid) tag
63
+ if name not in self.tag_mode_exceptions:
64
+ tag = '{}/{}'.format(tag, self.mode)
65
+ add_data(tag, data, self.step, *args, **kwargs)
66
+ return wrapper
67
+ else:
68
+ # default action for returning methods defined in this class, set_step() for instance.
69
+ try:
70
+ attr = object.__getattribute__(self, name)
71
+ except AttributeError:
72
+ raise AttributeError("type object '{}' has no attribute '{}'".format(self.selected_module, name))
73
+ return attr
model/model/loss.py ADDED
@@ -0,0 +1,5 @@
+ import torch.nn.functional as F
+ 
+ 
+ def nll_loss(output, target):
+     return F.nll_loss(output, target)
model/model/metric.py ADDED
@@ -0,0 +1,20 @@
+ import torch
+ 
+ 
+ def accuracy(output, target):
+     with torch.no_grad():
+         pred = torch.argmax(output, dim=1)
+         assert pred.shape[0] == len(target)
+         correct = 0
+         correct += torch.sum(pred == target).item()
+     return correct / len(target)
+ 
+ 
+ def top_k_acc(output, target, k=3):
+     with torch.no_grad():
+         pred = torch.topk(output, k, dim=1)[1]
+         assert pred.shape[0] == len(target)
+         correct = 0
+         for i in range(k):
+             correct += torch.sum(pred[:, i] == target).item()
+     return correct / len(target)
model/model/model.py ADDED
@@ -0,0 +1,22 @@
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from base import BaseModel
+ 
+ 
+ class MnistModel(BaseModel):
+     def __init__(self, num_classes=10):
+         super().__init__()
+         self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
+         self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
+         self.conv2_drop = nn.Dropout2d()
+         self.fc1 = nn.Linear(320, 50)
+         self.fc2 = nn.Linear(50, num_classes)
+ 
+     def forward(self, x):
+         x = F.relu(F.max_pool2d(self.conv1(x), 2))
+         x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
+         x = x.view(-1, 320)
+         x = F.relu(self.fc1(x))
+         x = F.dropout(x, training=self.training)
+         x = self.fc2(x)
+         return F.log_softmax(x, dim=1)
model/new_project.py ADDED
@@ -0,0 +1,18 @@
+ import sys
+ from pathlib import Path
+ from shutil import copytree, ignore_patterns
+ 
+ 
+ # This script initializes a new pytorch project with the template files.
+ # Run `python3 new_project.py ../MyNewProject`, then a new project named
+ # MyNewProject will be made
+ current_dir = Path()
+ assert (current_dir / 'new_project.py').is_file(), 'Script should be executed in the pytorch-template directory'
+ assert len(sys.argv) == 2, 'Specify a name for the new project. Example: python3 new_project.py MyNewProject'
+ 
+ project_name = Path(sys.argv[1])
+ target_dir = current_dir / project_name
+ 
+ ignore = [".git", "data", "saved", "new_project.py", "LICENSE", ".flake8", "README.md", "__pycache__"]
+ copytree(current_dir, target_dir, ignore=ignore_patterns(*ignore))
+ print('New project initialized at', target_dir.absolute().resolve())
model/parse_config.py ADDED
@@ -0,0 +1,157 @@
1
+ import os
2
+ import logging
3
+ from pathlib import Path
4
+ from functools import reduce, partial
5
+ from operator import getitem
6
+ from datetime import datetime
7
+ from logger import setup_logging
8
+ from utils import read_json, write_json
9
+
10
+
11
+ class ConfigParser:
12
+ def __init__(self, config, resume=None, modification=None, run_id=None):
13
+ """
14
+ class to parse configuration json file. Handles hyperparameters for training, initializations of modules, checkpoint saving
15
+ and logging module.
16
+ :param config: Dict containing configurations, hyperparameters for training. contents of `config.json` file for example.
17
+ :param resume: String, path to the checkpoint being loaded.
18
+ :param modification: Dict keychain:value, specifying position values to be replaced from config dict.
19
+ :param run_id: Unique Identifier for training processes. Used to save checkpoints and training log. Timestamp is being used as default
20
+ """
21
+ # load config file and apply modification
22
+ self._config = _update_config(config, modification)
23
+ self.resume = resume
24
+
25
+ # set save_dir where trained model and log will be saved.
26
+ save_dir = Path(self.config['trainer']['save_dir'])
27
+
28
+ exper_name = self.config['name']
29
+ if run_id is None: # use timestamp as default run-id
30
+ run_id = datetime.now().strftime(r'%m%d_%H%M%S')
31
+ self._save_dir = save_dir / 'models' / exper_name / run_id
32
+ self._log_dir = save_dir / 'log' / exper_name / run_id
33
+
34
+ # make directory for saving checkpoints and log.
35
+ exist_ok = run_id == ''
36
+ self.save_dir.mkdir(parents=True, exist_ok=exist_ok)
37
+ self.log_dir.mkdir(parents=True, exist_ok=exist_ok)
38
+
39
+ # save updated config file to the checkpoint dir
40
+ write_json(self.config, self.save_dir / 'config.json')
41
+
42
+ # configure logging module
43
+ setup_logging(self.log_dir)
44
+ self.log_levels = {
45
+ 0: logging.WARNING,
46
+ 1: logging.INFO,
47
+ 2: logging.DEBUG
48
+ }
49
+
50
+ @classmethod
51
+ def from_args(cls, args, options=''):
52
+ """
53
+ Initialize this class from some cli arguments. Used in train, test.
54
+ """
55
+ for opt in options:
56
+ args.add_argument(*opt.flags, default=None, type=opt.type)
57
+ if not isinstance(args, tuple):
58
+ args = args.parse_args()
59
+
60
+ if args.device is not None:
61
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.device
62
+ if args.resume is not None:
63
+ resume = Path(args.resume)
64
+ cfg_fname = resume.parent / 'config.json'
65
+ else:
66
+ msg_no_cfg = "Configuration file need to be specified. Add '-c config.json', for example."
67
+ assert args.config is not None, msg_no_cfg
68
+ resume = None
69
+ cfg_fname = Path(args.config)
70
+
71
+ config = read_json(cfg_fname)
72
+ if args.config and resume:
73
+ # update new config for fine-tuning
74
+ config.update(read_json(args.config))
75
+
76
+ # parse custom cli options into dictionary
77
+ modification = {opt.target : getattr(args, _get_opt_name(opt.flags)) for opt in options}
78
+ return cls(config, resume, modification)
79
+
80
+ def init_obj(self, name, module, *args, **kwargs):
81
+ """
82
+ Finds a function handle with the name given as 'type' in config, and returns the
83
+ instance initialized with corresponding arguments given.
84
+
85
+ `object = config.init_obj('name', module, a, b=1)`
86
+ is equivalent to
87
+ `object = module.name(a, b=1)`
88
+ """
89
+ module_name = self[name]['type']
90
+ module_args = dict(self[name]['args'])
91
+ assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed'
92
+ module_args.update(kwargs)
93
+ return getattr(module, module_name)(*args, **module_args)
94
+
95
+ def init_ftn(self, name, module, *args, **kwargs):
96
+ """
97
+ Finds a function handle with the name given as 'type' in config, and returns the
98
+ function with given arguments fixed with functools.partial.
99
+
100
+ `function = config.init_ftn('name', module, a, b=1)`
101
+ is equivalent to
102
+ `function = lambda *args, **kwargs: module.name(a, *args, b=1, **kwargs)`.
103
+ """
104
+ module_name = self[name]['type']
105
+ module_args = dict(self[name]['args'])
106
+ assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed'
107
+ module_args.update(kwargs)
108
+ return partial(getattr(module, module_name), *args, **module_args)
109
+
110
+ def __getitem__(self, name):
111
+ """Access items like ordinary dict."""
112
+ return self.config[name]
113
+
114
+ def get_logger(self, name, verbosity=2):
115
+ msg_verbosity = 'verbosity option {} is invalid. Valid options are {}.'.format(verbosity, self.log_levels.keys())
116
+ assert verbosity in self.log_levels, msg_verbosity
117
+ logger = logging.getLogger(name)
118
+ logger.setLevel(self.log_levels[verbosity])
119
+ return logger
120
+
121
+ # setting read-only attributes
122
+ @property
123
+ def config(self):
124
+ return self._config
125
+
126
+ @property
127
+ def save_dir(self):
128
+ return self._save_dir
129
+
130
+ @property
131
+ def log_dir(self):
132
+ return self._log_dir
133
+
134
+ # helper functions to update config dict with custom cli options
135
+ def _update_config(config, modification):
136
+ if modification is None:
137
+ return config
138
+
139
+ for k, v in modification.items():
140
+ if v is not None:
141
+ _set_by_path(config, k, v)
142
+ return config
143
+
144
+ def _get_opt_name(flags):
145
+ for flg in flags:
146
+ if flg.startswith('--'):
147
+ return flg.replace('--', '')
148
+ return flags[0].replace('--', '')
149
+
150
+ def _set_by_path(tree, keys, value):
151
+ """Set a value in a nested object in tree by sequence of keys."""
152
+ keys = keys.split(';')
153
+ _get_by_path(tree, keys[:-1])[keys[-1]] = value
154
+
155
+ def _get_by_path(tree, keys):
156
+ """Access a nested object in tree by sequence of keys."""
157
+ return reduce(getitem, keys, tree)
model/requirements.txt ADDED
@@ -0,0 +1,5 @@
+ torch>=1.1
+ torchvision
+ numpy
+ tqdm
+ tensorboard>=1.14
model/test.py ADDED
@@ -0,0 +1,81 @@
1
+ import argparse
2
+ import torch
3
+ from tqdm import tqdm
4
+ import data_loader.data_loaders as module_data
5
+ import model.loss as module_loss
6
+ import model.metric as module_metric
7
+ import model.model as module_arch
8
+ from parse_config import ConfigParser
9
+
10
+
11
+ def main(config):
12
+ logger = config.get_logger('test')
13
+
14
+ # setup data_loader instances
15
+ data_loader = getattr(module_data, config['data_loader']['type'])(
16
+ config['data_loader']['args']['data_dir'],
17
+ batch_size=512,
18
+ shuffle=False,
19
+ validation_split=0.0,
20
+ training=False,
21
+ num_workers=2
22
+ )
23
+
24
+ # build model architecture
25
+ model = config.init_obj('arch', module_arch)
26
+ logger.info(model)
27
+
28
+ # get function handles of loss and metrics
29
+ loss_fn = getattr(module_loss, config['loss'])
30
+ metric_fns = [getattr(module_metric, met) for met in config['metrics']]
31
+
32
+ logger.info('Loading checkpoint: {} ...'.format(config.resume))
33
+ checkpoint = torch.load(config.resume)
34
+ state_dict = checkpoint['state_dict']
35
+ if config['n_gpu'] > 1:
36
+ model = torch.nn.DataParallel(model)
37
+ model.load_state_dict(state_dict)
38
+
39
+ # prepare model for testing
40
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
41
+ model = model.to(device)
42
+ model.eval()
43
+
44
+ total_loss = 0.0
45
+ total_metrics = torch.zeros(len(metric_fns))
46
+
47
+ with torch.no_grad():
48
+ for i, (data, target) in enumerate(tqdm(data_loader)):
49
+ data, target = data.to(device), target.to(device)
50
+ output = model(data)
51
+
52
+ #
53
+ # save sample images, or do something with output here
54
+ #
55
+
56
+ # computing loss, metrics on test set
57
+ loss = loss_fn(output, target)
58
+ batch_size = data.shape[0]
59
+ total_loss += loss.item() * batch_size
60
+ for i, metric in enumerate(metric_fns):
61
+ total_metrics[i] += metric(output, target) * batch_size
62
+
63
+ n_samples = len(data_loader.sampler)
64
+ log = {'loss': total_loss / n_samples}
65
+ log.update({
66
+ met.__name__: total_metrics[i].item() / n_samples for i, met in enumerate(metric_fns)
67
+ })
68
+ logger.info(log)
69
+
70
+
71
+ if __name__ == '__main__':
72
+ args = argparse.ArgumentParser(description='PyTorch Template')
73
+ args.add_argument('-c', '--config', default=None, type=str,
74
+ help='config file path (default: None)')
75
+ args.add_argument('-r', '--resume', default=None, type=str,
76
+ help='path to latest checkpoint (default: None)')
77
+ args.add_argument('-d', '--device', default=None, type=str,
78
+ help='indices of GPUs to enable (default: all)')
79
+
80
+ config = ConfigParser.from_args(args)
81
+ main(config)
model/train.py ADDED
@@ -0,0 +1,73 @@
1
+ import argparse
2
+ import collections
3
+ import torch
4
+ import numpy as np
5
+ import data_loader.data_loaders as module_data
6
+ import model.loss as module_loss
7
+ import model.metric as module_metric
8
+ import model.model as module_arch
9
+ from parse_config import ConfigParser
10
+ from trainer import Trainer
11
+ from utils import prepare_device
12
+
13
+
14
+ # fix random seeds for reproducibility
15
+ SEED = 123
16
+ torch.manual_seed(SEED)
17
+ torch.backends.cudnn.deterministic = True
18
+ torch.backends.cudnn.benchmark = False
19
+ np.random.seed(SEED)
20
+
21
+ def main(config):
22
+ logger = config.get_logger('train')
23
+
24
+ # setup data_loader instances
25
+ data_loader = config.init_obj('data_loader', module_data)
26
+ valid_data_loader = data_loader.split_validation()
27
+
28
+ # build model architecture, then print to console
29
+ model = config.init_obj('arch', module_arch)
30
+ logger.info(model)
31
+
32
+ # prepare for (multi-device) GPU training
33
+ device, device_ids = prepare_device(config['n_gpu'])
34
+ model = model.to(device)
35
+ if len(device_ids) > 1:
36
+ model = torch.nn.DataParallel(model, device_ids=device_ids)
37
+
38
+ # get function handles of loss and metrics
39
+ criterion = getattr(module_loss, config['loss'])
40
+ metrics = [getattr(module_metric, met) for met in config['metrics']]
41
+
42
+ # build optimizer, learning rate scheduler. delete every lines containing lr_scheduler for disabling scheduler
43
+ trainable_params = filter(lambda p: p.requires_grad, model.parameters())
44
+ optimizer = config.init_obj('optimizer', torch.optim, trainable_params)
45
+ lr_scheduler = config.init_obj('lr_scheduler', torch.optim.lr_scheduler, optimizer)
46
+
47
+ trainer = Trainer(model, criterion, metrics, optimizer,
48
+ config=config,
49
+ device=device,
50
+ data_loader=data_loader,
51
+ valid_data_loader=valid_data_loader,
52
+ lr_scheduler=lr_scheduler)
53
+
54
+ trainer.train()
55
+
56
+
57
+ if __name__ == '__main__':
58
+ args = argparse.ArgumentParser(description='PyTorch Template')
59
+ args.add_argument('-c', '--config', default=None, type=str,
60
+ help='config file path (default: None)')
61
+ args.add_argument('-r', '--resume', default=None, type=str,
62
+ help='path to latest checkpoint (default: None)')
63
+ args.add_argument('-d', '--device', default=None, type=str,
64
+ help='indices of GPUs to enable (default: all)')
65
+
66
+ # custom cli options to modify configuration from default values given in json file.
67
+ CustomArgs = collections.namedtuple('CustomArgs', 'flags type target')
68
+ options = [
69
+ CustomArgs(['--lr', '--learning_rate'], type=float, target='optimizer;args;lr'),
70
+ CustomArgs(['--bs', '--batch_size'], type=int, target='data_loader;args;batch_size')
71
+ ]
72
+ config = ConfigParser.from_args(args, options)
73
+ main(config)
model/trainer/__init__.py ADDED
@@ -0,0 +1 @@
+ from .trainer import *
model/trainer/trainer.py ADDED
@@ -0,0 +1,110 @@
1
+ import numpy as np
2
+ import torch
3
+ from torchvision.utils import make_grid
4
+ from base import BaseTrainer
5
+ from utils import inf_loop, MetricTracker
6
+
7
+
8
+ class Trainer(BaseTrainer):
9
+ """
10
+ Trainer class
11
+ """
12
+ def __init__(self, model, criterion, metric_ftns, optimizer, config, device,
13
+ data_loader, valid_data_loader=None, lr_scheduler=None, len_epoch=None):
14
+ super().__init__(model, criterion, metric_ftns, optimizer, config)
15
+ self.config = config
16
+ self.device = device
17
+ self.data_loader = data_loader
18
+ if len_epoch is None:
19
+ # epoch-based training
20
+ self.len_epoch = len(self.data_loader)
21
+ else:
22
+ # iteration-based training
23
+ self.data_loader = inf_loop(data_loader)
24
+ self.len_epoch = len_epoch
25
+ self.valid_data_loader = valid_data_loader
26
+ self.do_validation = self.valid_data_loader is not None
27
+ self.lr_scheduler = lr_scheduler
28
+ self.log_step = int(np.sqrt(data_loader.batch_size))
29
+
30
+ self.train_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)
31
+ self.valid_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)
32
+
33
+ def _train_epoch(self, epoch):
34
+ """
35
+ Training logic for an epoch
36
+
37
+ :param epoch: Integer, current training epoch.
38
+ :return: A log that contains average loss and metric in this epoch.
39
+ """
40
+ self.model.train()
41
+ self.train_metrics.reset()
42
+ for batch_idx, (data, target) in enumerate(self.data_loader):
43
+ data, target = data.to(self.device), target.to(self.device)
44
+
45
+ self.optimizer.zero_grad()
46
+ output = self.model(data)
47
+ loss = self.criterion(output, target)
48
+ loss.backward()
49
+ self.optimizer.step()
50
+
51
+ self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
52
+ self.train_metrics.update('loss', loss.item())
53
+ for met in self.metric_ftns:
54
+ self.train_metrics.update(met.__name__, met(output, target))
55
+
56
+ if batch_idx % self.log_step == 0:
57
+ self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
58
+ epoch,
59
+ self._progress(batch_idx),
60
+ loss.item()))
61
+ self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
62
+
63
+ if batch_idx == self.len_epoch:
64
+ break
65
+ log = self.train_metrics.result()
66
+
67
+ if self.do_validation:
68
+ val_log = self._valid_epoch(epoch)
69
+ log.update(**{'val_'+k : v for k, v in val_log.items()})
70
+
71
+ if self.lr_scheduler is not None:
72
+ self.lr_scheduler.step()
73
+ return log
74
+
75
+ def _valid_epoch(self, epoch):
76
+ """
77
+ Validate after training an epoch
78
+
79
+ :param epoch: Integer, current training epoch.
80
+ :return: A log that contains information about validation
81
+ """
82
+ self.model.eval()
83
+ self.valid_metrics.reset()
84
+ with torch.no_grad():
85
+ for batch_idx, (data, target) in enumerate(self.valid_data_loader):
86
+ data, target = data.to(self.device), target.to(self.device)
87
+
88
+ output = self.model(data)
89
+ loss = self.criterion(output, target)
90
+
91
+ self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid')
92
+ self.valid_metrics.update('loss', loss.item())
93
+ for met in self.metric_ftns:
94
+ self.valid_metrics.update(met.__name__, met(output, target))
95
+ self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
96
+
97
+ # add histogram of model parameters to the tensorboard
98
+ for name, p in self.model.named_parameters():
99
+ self.writer.add_histogram(name, p, bins='auto')
100
+ return self.valid_metrics.result()
101
+
102
+ def _progress(self, batch_idx):
103
+ base = '[{}/{} ({:.0f}%)]'
104
+ if hasattr(self.data_loader, 'n_samples'):
105
+ current = batch_idx * self.data_loader.batch_size
106
+ total = self.data_loader.n_samples
107
+ else:
108
+ current = batch_idx
109
+ total = self.len_epoch
110
+ return base.format(current, total, 100.0 * current / total)
model/utils/__init__.py ADDED
@@ -0,0 +1 @@
+ from .util import *
model/utils/util.py ADDED
@@ -0,0 +1,67 @@
1
+ import json
2
+ import torch
3
+ import pandas as pd
4
+ from pathlib import Path
5
+ from itertools import repeat
6
+ from collections import OrderedDict
7
+
8
+
9
+ def ensure_dir(dirname):
10
+ dirname = Path(dirname)
11
+ if not dirname.is_dir():
12
+ dirname.mkdir(parents=True, exist_ok=False)
13
+
14
+ def read_json(fname):
15
+ fname = Path(fname)
16
+ with fname.open('rt') as handle:
17
+ return json.load(handle, object_hook=OrderedDict)
18
+
19
+ def write_json(content, fname):
20
+ fname = Path(fname)
21
+ with fname.open('wt') as handle:
22
+ json.dump(content, handle, indent=4, sort_keys=False)
23
+
24
+ def inf_loop(data_loader):
25
+ ''' wrapper function for endless data loader. '''
26
+ for loader in repeat(data_loader):
27
+ yield from loader
28
+
29
+ def prepare_device(n_gpu_use):
30
+ """
31
+ setup GPU device if available. get gpu device indices which are used for DataParallel
32
+ """
33
+ n_gpu = torch.cuda.device_count()
34
+ if n_gpu_use > 0 and n_gpu == 0:
35
+ print("Warning: There\'s no GPU available on this machine,"
36
+ "training will be performed on CPU.")
37
+ n_gpu_use = 0
38
+ if n_gpu_use > n_gpu:
39
+ print(f"Warning: the number of GPUs configured to use is {n_gpu_use}, but only {n_gpu} are "
40
+ "available on this machine.")
41
+ n_gpu_use = n_gpu
42
+ device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
43
+ list_ids = list(range(n_gpu_use))
44
+ return device, list_ids
45
+
46
+ class MetricTracker:
47
+ def __init__(self, *keys, writer=None):
48
+ self.writer = writer
49
+ self._data = pd.DataFrame(index=keys, columns=['total', 'counts', 'average'])
50
+ self.reset()
51
+
52
+ def reset(self):
53
+ for col in self._data.columns:
54
+ self._data[col].values[:] = 0
55
+
56
+ def update(self, key, value, n=1):
57
+ if self.writer is not None:
58
+ self.writer.add_scalar(key, value)
59
+ self._data.total[key] += value * n
60
+ self._data.counts[key] += n
61
+ self._data.average[key] = self._data.total[key] / self._data.counts[key]
62
+
63
+ def avg(self, key):
64
+ return self._data.average[key]
65
+
66
+ def result(self):
67
+ return dict(self._data.average)