rzimmerdev commited on
Commit
b358118
·
1 Parent(s): 82fdb01

Implemented training loops and logging functionalities with PyTorch Lightning

Browse files
Files changed (2) hide show
  1. notebooks/trainer.ipynb +4 -4
  2. src/trainer.py +45 -0
notebooks/trainer.ipynb CHANGED
@@ -3,14 +3,14 @@
3
  {
4
  "cell_type": "code",
5
  "execution_count": null,
 
 
6
  "metadata": {
7
- "collapsed": true,
8
  "pycharm": {
9
  "name": "#%%\n"
10
  }
11
- },
12
- "outputs": [],
13
- "source": []
14
  }
15
  ],
16
  "metadata": {
 
3
  {
4
  "cell_type": "code",
5
  "execution_count": null,
6
+ "outputs": [],
7
+ "source": [],
8
  "metadata": {
9
+ "collapsed": false,
10
  "pycharm": {
11
  "name": "#%%\n"
12
  }
13
+ }
 
 
14
  }
15
  ],
16
  "metadata": {
src/trainer.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+ import torch.optim
4
+ import pytorch_lightning as pl
5
+
6
+
7
+ class LitTrainer(pl.LightningModule):
8
+ def __init__(self, model, loss_fn, optim):
9
+ super().__init__()
10
+ self.model = model
11
+ self.loss_fn = loss_fn
12
+ self.optim = optim
13
+
14
+ def training_step(self, batch, batch_idx):
15
+ x, y = batch
16
+ x = x.to(torch.float32)
17
+
18
+ y_pred = self.model(x).reshape(1, -1)
19
+ train_loss = self.loss_fn(y_pred, y)
20
+
21
+ self.log("train_loss", train_loss)
22
+ return train_loss
23
+
24
+ def validation_step(self, batch, batch_idx):
25
+ # this is the validation loop
26
+ x, y = batch
27
+ x = x.to(torch.float32)
28
+
29
+ y_pred = self.model(x).reshape(1, -1)
30
+ validate_loss = self.loss_fn(y_pred, y)
31
+
32
+ self.log("val_loss", validate_loss)
33
+
34
+ def test_step(self, batch, batch_idx):
35
+ # this is the test loop
36
+ x, y = batch
37
+ x = x.to(torch.float32)
38
+
39
+ y_pred = self.model(x).reshape(1, -1)
40
+ test_loss = self.loss_fn(y_pred, y)
41
+
42
+ self.log("test_loss", test_loss)
43
+
44
+ def configure_optimizers(self):
45
+ return self.optim