Hannes Kuchelmeister commited on
Commit
ba9c868
1 Parent(s): ad947ed

Add code so loss function uses torch.Size([x,1]) instead of torch.Size([x])

Browse files

This was done to prevent the error message:
"Using a target size (torch.Size([64])) that is different to the input size (torch.Size([64, 1]))".

Files changed (1) hide show
  1. src/models/focus_module.py +2 -2
src/models/focus_module.py CHANGED
@@ -85,7 +85,7 @@ class FocusLitModule(LightningModule):
85
  x = batch["image"]
86
  y = batch["focus_value"]
87
  logits = self.forward(x)
88
- loss = self.criterion(logits, y)
89
  preds = torch.squeeze(logits)
90
  return loss, preds, y
91
 
@@ -210,7 +210,7 @@ class FocusMSELitModule(LightningModule):
210
  x = batch["image"]
211
  y = batch["focus_value"]
212
  logits = self.forward(x)
213
- loss = self.criterion(logits, y)
214
  preds = torch.squeeze(logits)
215
  return loss, preds, y
216
 
 
85
  x = batch["image"]
86
  y = batch["focus_value"]
87
  logits = self.forward(x)
88
+ loss = self.criterion(logits, y.unsqueeze(1))
89
  preds = torch.squeeze(logits)
90
  return loss, preds, y
91
 
 
210
  x = batch["image"]
211
  y = batch["focus_value"]
212
  logits = self.forward(x)
213
+ loss = self.criterion(logits, y.unsqueeze(1))
214
  preds = torch.squeeze(logits)
215
  return loss, preds, y
216