Hannes Kuchelmeister
commited on
Commit
•
754b856
1
Parent(s):
90cf80d
rename metric to mae
Browse files
models/src/models/focus_module.py
CHANGED
@@ -71,12 +71,12 @@ class FocusLitModule(LightningModule):
|
|
71 |
|
72 |
# use separate metric instance for train, val and test step
|
73 |
# to ensure a proper reduction over the epoch
|
74 |
-
self.
|
75 |
-
self.
|
76 |
-
self.
|
77 |
|
78 |
# for logging best so far validation accuracy
|
79 |
-
self.
|
80 |
|
81 |
def forward(self, x: torch.Tensor):
|
82 |
return self.model(x)
|
@@ -93,9 +93,9 @@ class FocusLitModule(LightningModule):
|
|
93 |
loss, preds, targets = self.step(batch)
|
94 |
|
95 |
# log train metrics
|
96 |
-
|
97 |
self.log("train/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
|
98 |
-
self.log("train/
|
99 |
|
100 |
# we can return here dict with any tensors
|
101 |
# and then read it in some callback or in `training_epoch_end()`` below
|
@@ -110,26 +110,26 @@ class FocusLitModule(LightningModule):
|
|
110 |
loss, preds, targets = self.step(batch)
|
111 |
|
112 |
# log val metrics
|
113 |
-
|
114 |
self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
|
115 |
-
self.log("val/
|
116 |
|
117 |
return {"loss": loss, "preds": preds, "targets": targets}
|
118 |
|
119 |
def validation_epoch_end(self, outputs: List[Any]):
|
120 |
-
|
121 |
-
self.
|
122 |
self.log(
|
123 |
-
"val/
|
124 |
)
|
125 |
|
126 |
def test_step(self, batch: Any, batch_idx: int):
|
127 |
loss, preds, targets = self.step(batch)
|
128 |
|
129 |
# log test metrics
|
130 |
-
|
131 |
self.log("test/loss", loss, on_step=False, on_epoch=True)
|
132 |
-
self.log("test/
|
133 |
|
134 |
return {"loss": loss, "preds": preds, "targets": targets}
|
135 |
|
@@ -138,9 +138,9 @@ class FocusLitModule(LightningModule):
|
|
138 |
|
139 |
def on_epoch_end(self):
|
140 |
# reset metrics at the end of every epoch
|
141 |
-
self.
|
142 |
-
self.
|
143 |
-
self.
|
144 |
|
145 |
def configure_optimizers(self):
|
146 |
"""Choose what optimizers and learning-rate schedulers.
|
|
|
71 |
|
72 |
# use separate metric instance for train, val and test step
|
73 |
# to ensure a proper reduction over the epoch
|
74 |
+
self.train_mae = MeanAbsoluteError()
|
75 |
+
self.val_mae = MeanAbsoluteError()
|
76 |
+
self.test_mae = MeanAbsoluteError()
|
77 |
|
78 |
# for logging best so far validation accuracy
|
79 |
+
self.val_mae_best = MinMetric()
|
80 |
|
81 |
def forward(self, x: torch.Tensor):
|
82 |
return self.model(x)
|
|
|
93 |
loss, preds, targets = self.step(batch)
|
94 |
|
95 |
# log train metrics
|
96 |
+
mae = self.train_mae(preds, targets)
|
97 |
self.log("train/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
|
98 |
+
self.log("train/mae", mae, on_step=False, on_epoch=True, prog_bar=True)
|
99 |
|
100 |
# we can return here dict with any tensors
|
101 |
# and then read it in some callback or in `training_epoch_end()`` below
|
|
|
110 |
loss, preds, targets = self.step(batch)
|
111 |
|
112 |
# log val metrics
|
113 |
+
mae = self.val_mae(preds, targets)
|
114 |
self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
|
115 |
+
self.log("val/mae", mae, on_step=False, on_epoch=True, prog_bar=True)
|
116 |
|
117 |
return {"loss": loss, "preds": preds, "targets": targets}
|
118 |
|
119 |
def validation_epoch_end(self, outputs: List[Any]):
|
120 |
+
mae = self.val_mae.compute() # get val accuracy from current epoch
|
121 |
+
self.val_mae_best.update(mae)
|
122 |
self.log(
|
123 |
+
"val/mae_best", self.val_mae_best.compute(), on_epoch=True, prog_bar=True
|
124 |
)
|
125 |
|
126 |
def test_step(self, batch: Any, batch_idx: int):
|
127 |
loss, preds, targets = self.step(batch)
|
128 |
|
129 |
# log test metrics
|
130 |
+
mae = self.test_mae(preds, targets)
|
131 |
self.log("test/loss", loss, on_step=False, on_epoch=True)
|
132 |
+
self.log("test/mae", mae, on_step=False, on_epoch=True)
|
133 |
|
134 |
return {"loss": loss, "preds": preds, "targets": targets}
|
135 |
|
|
|
138 |
|
139 |
def on_epoch_end(self):
|
140 |
# reset metrics at the end of every epoch
|
141 |
+
self.train_mae.reset()
|
142 |
+
self.test_mae.reset()
|
143 |
+
self.val_mae.reset()
|
144 |
|
145 |
def configure_optimizers(self):
|
146 |
"""Choose what optimizers and learning-rate schedulers.
|