File size: 1,974 Bytes
c3f6f40
 
 
 
 
 
 
 
 
a490849
c3f6f40
 
 
 
d8d5dd8
c3f6f40
 
 
 
a490849
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1a31ea8
a490849
 
 
 
c3f6f40
 
 
 
 
 
 
 
 
 
a490849
c3f6f40
 
 
 
 
 
 
 
a490849
 
 
 
c3f6f40
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from typing import Any, List

import torch
import torch.nn.functional as F
from torch import nn
from pytorch_lightning import LightningModule
from torchmetrics import MaxMetric, MeanAbsoluteError, MinMetric
from torchmetrics.classification.accuracy import Accuracy
import torchvision.models as models
import kornia


def vol4(img):
    """Vollath's F4 autocorrelation focus measure, inverted so lower = better.

    The image is collapsed to a single grey plane by averaging over dim 0
    (presumably the channel axis of a (C, H, W) tensor -- TODO confirm with
    callers).  With ``I`` the grey plane and the shift taken along its dim 0,
    F4 is::

        F4 = sum(I[i] * I[i + 1]) - sum(I[i] * I[i + 2])

    and the returned score is ``100 / F4``.

    Args:
        img: image tensor whose dim 0 is averaged away; the remaining grey
            plane must be at least 2-D (it is sliced as ``[:, :]``).

    Returns:
        0-dim tensor ``100 / F4``.
    """
    img_grey = torch.mean(img, dim=0)
    # Correlation of every row with the next row, and with the row after that.
    lag1 = torch.sum(img_grey[:-1, :] * img_grey[1:, :])
    lag2 = torch.sum(img_grey[:-2, :] * img_grey[2:, :])
    # Invert the whole F4 difference, not just the first sum: `/` binds
    # tighter than `-`, so `100 / lag1 - lag2` would compute
    # `(100 / lag1) - lag2`.
    return 100 / (lag1 - lag2)


def laplacian(img):
    """Mean-Laplacian focus measure, inverted so that lower = better.

    ``img`` is averaged over dim 0 to obtain a grey plane, a leading
    singleton axis is re-added, kornia's Laplacian filter with kernel size 3
    is applied, and ``100 / mean(response)`` is returned.

    NOTE(review): kornia's filters document a 4-D (B, C, H, W) input.  If
    callers pass a 3-D (C, H, W) image, the grey tensor built here is only
    3-D -- confirm the shape callers actually pass.
    NOTE(review): the signed Laplacian response can be negative or close to
    zero, which makes this ratio unstable; confirm whether a mean of
    absolute or squared responses was intended.
    """
    grey = img.mean(dim=0)[None]
    response = kornia.filters.laplacian(grey, 3)
    # Invert the mean response so that the "lower = better" convention holds.
    return 100 / response.mean()


def midfrequency_dct(img):
    """Mid-frequency energy focus measure, inverted so that lower = better.

    The grey plane (mean of ``img`` over dim 0, with a leading singleton
    axis re-added) is filtered by a fixed 4x4 +/-1 kernel with sign pattern::

        + + - -
        + + - -
        - - + +
        - - + +

    The squared filter responses are summed, and ``100 / total`` is returned.

    NOTE(review): kornia's ``filter2d`` documents a 4-D (B, C, H, W) input.
    If callers pass a 3-D (C, H, W) image, the grey tensor built here is only
    3-D -- confirm the shape callers actually pass.

    Args:
        img: image tensor whose dim 0 is averaged away.

    Returns:
        0-dim tensor ``100 / sum(response ** 2)``.
    """
    img_grey = torch.mean(img, dim=0).unsqueeze(0)

    # Build the kernel directly in the image's dtype and on its device,
    # instead of an int64 CPU tensor that has to be cast before convolution.
    kernel = torch.tensor(
        [
            [
                [1.0, 1.0, -1.0, -1.0],
                [1.0, 1.0, -1.0, -1.0],
                [-1.0, -1.0, 1.0, 1.0],
                [-1.0, -1.0, 1.0, 1.0],
            ]
        ],
        dtype=img_grey.dtype,
        device=img_grey.device,
    )

    filtered = torch.square(kornia.filters.filter2d(img_grey, kernel))
    # `total` rather than `sum`, so the builtin is not shadowed.
    total = torch.sum(filtered)
    return 100 / total


class TraditionalLitModule(LightningModule):
    """Non-trainable LightningModule that wraps a classical focus measure."""

    def __init__(
        self,
        method: str = "vol4",
    ):
        """Initialize function for a traditional focus measurement `model`. It cannot be trained.

        Args:
            method (str, optional): The method to use for predicting focus. Defaults to "vol4".
                    Possible values are: vol4, mean_laplacian, midfrequency_dct

        Raises:
            ValueError: if ``method`` is not one of the known names (a
                subclass of ``Exception``, so existing handlers still match).
        """
        super().__init__()

        # Map each public method name to its focus-measure function.
        functions = {
            "vol4": vol4,
            "mean_laplacian": laplacian,
            "midfrequency_dct": midfrequency_dct,
        }
        try:
            self.function = functions[method]
        except KeyError:
            # Fail at construction time with a clear message, rather than
            # leaving `self.function` unset and failing later in `forward`.
            raise ValueError(
                f"Unknown focus method {method!r}; "
                f"expected one of {sorted(functions)}"
            ) from None

    def forward(self, x):
        """Return the focus score of ``x`` from the selected measure."""
        return self.function(x)