Hannes Kuchelmeister
commited on
Commit
·
ba9c868
1
Parent(s):
ad947ed
Add code so loss function uses torch.Size([x,1]) instead of torch.Size([x])
Browse filesThis was done to prevent the error message:
"Using a target size (torch.Size([64])) that is different to the input size (torch.Size([64, 1]))".
src/models/focus_module.py
CHANGED
|
@@ -85,7 +85,7 @@ class FocusLitModule(LightningModule):
|
|
| 85 |
x = batch["image"]
|
| 86 |
y = batch["focus_value"]
|
| 87 |
logits = self.forward(x)
|
| 88 |
-
loss = self.criterion(logits, y)
|
| 89 |
preds = torch.squeeze(logits)
|
| 90 |
return loss, preds, y
|
| 91 |
|
|
@@ -210,7 +210,7 @@ class FocusMSELitModule(LightningModule):
|
|
| 210 |
x = batch["image"]
|
| 211 |
y = batch["focus_value"]
|
| 212 |
logits = self.forward(x)
|
| 213 |
-
loss = self.criterion(logits, y)
|
| 214 |
preds = torch.squeeze(logits)
|
| 215 |
return loss, preds, y
|
| 216 |
|
|
|
|
| 85 |
x = batch["image"]
|
| 86 |
y = batch["focus_value"]
|
| 87 |
logits = self.forward(x)
|
| 88 |
+
loss = self.criterion(logits, y.unsqueeze(1))
|
| 89 |
preds = torch.squeeze(logits)
|
| 90 |
return loss, preds, y
|
| 91 |
|
|
|
|
| 210 |
x = batch["image"]
|
| 211 |
y = batch["focus_value"]
|
| 212 |
logits = self.forward(x)
|
| 213 |
+
loss = self.criterion(logits, y.unsqueeze(1))
|
| 214 |
preds = torch.squeeze(logits)
|
| 215 |
return loss, preds, y
|
| 216 |
|