Hannes Kuchelmeister commited on
Commit
90cf80d
1 Parent(s): ee280e5

fix predictions to not use argmax

Browse files
Files changed (1) hide show
  1. models/src/models/focus_module.py +1 -1
models/src/models/focus_module.py CHANGED
@@ -86,7 +86,7 @@ class FocusLitModule(LightningModule):
86
  y = batch["focus_value"]
87
  logits = self.forward(x)
88
  loss = self.criterion(logits, y)
89
- preds = torch.argmax(logits, dim=1)
90
  return loss, preds, y
91
 
92
  def training_step(self, batch: Any, batch_idx: int):
 
86
  y = batch["focus_value"]
87
  logits = self.forward(x)
88
  loss = self.criterion(logits, y)
89
+ preds = torch.squeeze(logits)
90
  return loss, preds, y
91
 
92
  def training_step(self, batch: Any, batch_idx: int):