Hannes Kuchelmeister
commited on
Commit
•
90cf80d
1
Parent(s):
ee280e5
fix predictions to not use argmax
Browse files
models/src/models/focus_module.py
CHANGED
@@ -86,7 +86,7 @@ class FocusLitModule(LightningModule):
|
|
86 |
y = batch["focus_value"]
|
87 |
logits = self.forward(x)
|
88 |
loss = self.criterion(logits, y)
|
89 |
-
preds = torch.
|
90 |
return loss, preds, y
|
91 |
|
92 |
def training_step(self, batch: Any, batch_idx: int):
|
|
|
86 |
y = batch["focus_value"]
|
87 |
logits = self.forward(x)
|
88 |
loss = self.criterion(logits, y)
|
89 |
+
preds = torch.squeeze(logits)
|
90 |
return loss, preds, y
|
91 |
|
92 |
def training_step(self, batch: Any, batch_idx: int):
|