Hannes Kuchelmeister commited on
Commit
c3f6f40
1 Parent(s): e3821c2

add traditional method for comparison

Browse files
notebooks/5.0-hfk-comparing-to-traditional.ipynb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:733ce4b03b6efa6a4a0996783c8d88dcd64e9e269b16bfcb747e81dfd78b5743
3
+ size 11339
src/models/focus_traditional.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, List
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn
6
+ from pytorch_lightning import LightningModule
7
+ from torchmetrics import MaxMetric, MeanAbsoluteError, MinMetric
8
+ from torchmetrics.classification.accuracy import Accuracy
9
+ import torchvision.models as models
10
+
11
+
12
+ def vol4(img):
13
+ img_grey = torch.mean(img, dim=0)
14
+ return torch.sum(torch.mul(img_grey[:-1, :], img_grey[1:, :])) - torch.sum(
15
+ torch.mul(img_grey[:-2, :], img_grey[2:, :])
16
+ )
17
+
18
+
19
+ class TraditionalLitModule(LightningModule):
20
+ def __init__(
21
+ self,
22
+ method: str = "vol4",
23
+ ):
24
+
25
+ """Initialize function for a traditional focus measurement `model`. It cannot be trained.
26
+
27
+ Args:
28
+ method (str, optional): The method to use for predicting focus. Defaults to "vol4".
29
+
30
+ Raises:
31
+ Exception: raises exception if method parameter is not known
32
+ """
33
+ super().__init__()
34
+
35
+ if method == "vol4":
36
+ self.function = vol4
37
+
38
+ def forward(self, x):
39
+ return self.function(x)