Hannes Kuchelmeister committed on
Commit
754b856
1 Parent(s): 90cf80d

rename metric to mae

Browse files
Files changed (1) hide show
  1. models/src/models/focus_module.py +16 -16
models/src/models/focus_module.py CHANGED
@@ -71,12 +71,12 @@ class FocusLitModule(LightningModule):
71
 
72
  # use separate metric instance for train, val and test step
73
  # to ensure a proper reduction over the epoch
74
- self.train_acc = MeanAbsoluteError()
75
- self.val_acc = MeanAbsoluteError()
76
- self.test_acc = MeanAbsoluteError()
77
 
78
  # for logging best so far validation accuracy
79
- self.val_acc_best = MinMetric()
80
 
81
  def forward(self, x: torch.Tensor):
82
  return self.model(x)
@@ -93,9 +93,9 @@ class FocusLitModule(LightningModule):
93
  loss, preds, targets = self.step(batch)
94
 
95
  # log train metrics
96
- acc = self.train_acc(preds, targets)
97
  self.log("train/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
98
- self.log("train/acc", acc, on_step=False, on_epoch=True, prog_bar=True)
99
 
100
  # we can return here dict with any tensors
101
  # and then read it in some callback or in `training_epoch_end()`` below
@@ -110,26 +110,26 @@ class FocusLitModule(LightningModule):
110
  loss, preds, targets = self.step(batch)
111
 
112
  # log val metrics
113
- acc = self.val_acc(preds, targets)
114
  self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
115
- self.log("val/acc", acc, on_step=False, on_epoch=True, prog_bar=True)
116
 
117
  return {"loss": loss, "preds": preds, "targets": targets}
118
 
119
  def validation_epoch_end(self, outputs: List[Any]):
120
- acc = self.val_acc.compute() # get val accuracy from current epoch
121
- self.val_acc_best.update(acc)
122
  self.log(
123
- "val/acc_best", self.val_acc_best.compute(), on_epoch=True, prog_bar=True
124
  )
125
 
126
  def test_step(self, batch: Any, batch_idx: int):
127
  loss, preds, targets = self.step(batch)
128
 
129
  # log test metrics
130
- acc = self.test_acc(preds, targets)
131
  self.log("test/loss", loss, on_step=False, on_epoch=True)
132
- self.log("test/acc", acc, on_step=False, on_epoch=True)
133
 
134
  return {"loss": loss, "preds": preds, "targets": targets}
135
 
@@ -138,9 +138,9 @@ class FocusLitModule(LightningModule):
138
 
139
  def on_epoch_end(self):
140
  # reset metrics at the end of every epoch
141
- self.train_acc.reset()
142
- self.test_acc.reset()
143
- self.val_acc.reset()
144
 
145
  def configure_optimizers(self):
146
  """Choose what optimizers and learning-rate schedulers.
 
71
 
72
  # use separate metric instance for train, val and test step
73
  # to ensure a proper reduction over the epoch
74
+ self.train_mae = MeanAbsoluteError()
75
+ self.val_mae = MeanAbsoluteError()
76
+ self.test_mae = MeanAbsoluteError()
77
 
78
  # for logging best so far validation accuracy
79
+ self.val_mae_best = MinMetric()
80
 
81
  def forward(self, x: torch.Tensor):
82
  return self.model(x)
 
93
  loss, preds, targets = self.step(batch)
94
 
95
  # log train metrics
96
+ mae = self.train_mae(preds, targets)
97
  self.log("train/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
98
+ self.log("train/mae", mae, on_step=False, on_epoch=True, prog_bar=True)
99
 
100
  # we can return here dict with any tensors
101
  # and then read it in some callback or in `training_epoch_end()`` below
 
110
  loss, preds, targets = self.step(batch)
111
 
112
  # log val metrics
113
+ mae = self.val_mae(preds, targets)
114
  self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
115
+ self.log("val/mae", mae, on_step=False, on_epoch=True, prog_bar=True)
116
 
117
  return {"loss": loss, "preds": preds, "targets": targets}
118
 
119
  def validation_epoch_end(self, outputs: List[Any]):
120
+ mae = self.val_mae.compute() # get val accuracy from current epoch
121
+ self.val_mae_best.update(mae)
122
  self.log(
123
+ "val/mae_best", self.val_mae_best.compute(), on_epoch=True, prog_bar=True
124
  )
125
 
126
  def test_step(self, batch: Any, batch_idx: int):
127
  loss, preds, targets = self.step(batch)
128
 
129
  # log test metrics
130
+ mae = self.test_mae(preds, targets)
131
  self.log("test/loss", loss, on_step=False, on_epoch=True)
132
+ self.log("test/mae", mae, on_step=False, on_epoch=True)
133
 
134
  return {"loss": loss, "preds": preds, "targets": targets}
135
 
 
138
 
139
  def on_epoch_end(self):
140
  # reset metrics at the end of every epoch
141
+ self.train_mae.reset()
142
+ self.test_mae.reset()
143
+ self.val_mae.reset()
144
 
145
  def configure_optimizers(self):
146
  """Choose what optimizers and learning-rate schedulers.