File size: 1,974 Bytes
c3f6f40 a490849 c3f6f40 d8d5dd8 c3f6f40 a490849 1a31ea8 a490849 c3f6f40 a490849 c3f6f40 a490849 c3f6f40 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
from typing import Any, List
import torch
import torch.nn.functional as F
from torch import nn
from pytorch_lightning import LightningModule
from torchmetrics import MaxMetric, MeanAbsoluteError, MinMetric
from torchmetrics.classification.accuracy import Accuracy
import torchvision.models as models
import kornia
def vol4(img):
    """Return an inverted Vollath F4 focus score for one image.

    The image is averaged over dim 0 (presumably the channel axis of a
    ``(C, H, W)`` tensor -- TODO confirm against callers) to obtain a 2-D
    grey-scale map ``I``. Vollath's F4 autocorrelation is then taken along
    dim 0 of that map::

        F4 = sum(I[y] * I[y + 1]) - sum(I[y] * I[y + 2])

    The result is returned as ``100 / F4``, so that (like the other measures
    in this module) a lower value means a sharper image.

    Args:
        img: Image tensor whose dim 0 is averaged away, leaving a 2-D map.

    Returns:
        Scalar tensor ``100 / F4``.
    """
    img_grey = torch.mean(img, dim=0)
    # Products of every row with the row one and two steps further along dim 0.
    lag1 = torch.sum(torch.mul(img_grey[:-1, :], img_grey[1:, :]))
    lag2 = torch.sum(torch.mul(img_grey[:-2, :], img_grey[2:, :]))
    # Parenthesised on purpose: `100 / lag1 - lag2` would invert only lag1.
    return 100 / (lag1 - lag2)
def laplacian(img):
    """Return an inverted mean-Laplacian focus score for one image.

    The image is averaged over dim 0 to a grey-scale map, filtered with
    kornia's 3x3 Laplacian, and the mean filter response is returned as
    ``100 / mean`` so that a lower value means a sharper image.

    Args:
        img: Image tensor; dim 0 is averaged away. Presumably ``(C, H, W)``
            (TODO confirm against callers). A batched ``(B, C, H, W)`` input
            is also accepted, but then dim 0 is the batch axis.

    Returns:
        Scalar tensor ``100 / mean(laplacian(grey))``.
    """
    img_grey = torch.mean(img, dim=0).unsqueeze(0)
    if img_grey.dim() == 3:
        # kornia filters require (B, C, H, W). A (C, H, W) input gives a
        # (1, H, W) grey map here, so add the missing leading axis.
        img_grey = img_grey.unsqueeze(0)
    filtered = kornia.filters.laplacian(img_grey, 3)
    mean_response = torch.mean(filtered)
    # NOTE(review): the signed Laplacian response can largely cancel out over
    # an image, leaving this mean near 0 or negative. Many "mean Laplacian"
    # focus measures average torch.abs(filtered) instead -- confirm intent.
    return 100 / mean_response  # invert mean to fit metric of lower = better
def midfrequency_dct(img):
    """Return an inverted mid-frequency DCT energy focus score for one image.

    The image is averaged over dim 0 to a grey-scale map and convolved with a
    4x4 +/-1 mask whose quadrants alternate in sign. The squared filter
    responses are summed, and the total is returned as ``100 / energy`` so
    that a lower value means a sharper image.

    Args:
        img: Image tensor; dim 0 is averaged away. Presumably ``(C, H, W)``
            (TODO confirm against callers). A batched ``(B, C, H, W)`` input
            is also accepted, but then dim 0 is the batch axis.

    Returns:
        Scalar tensor ``100 / sum(filter2d(grey, kernel) ** 2)``.
    """
    img_grey = torch.mean(img, dim=0).unsqueeze(0)
    if img_grey.dim() == 3:
        # kornia filters require (B, C, H, W). A (C, H, W) input gives a
        # (1, H, W) grey map here, so add the missing leading axis.
        img_grey = img_grey.unsqueeze(0)
    # Shape (1, 4, 4): kornia's filter2d takes a (B, kH, kW) kernel.
    kernel = torch.tensor(
        [
            [
                [1, 1, -1, -1],
                [1, 1, -1, -1],
                [-1, -1, 1, 1],
                [-1, -1, 1, 1],
            ]
        ],
        dtype=img_grey.dtype,
        device=img_grey.device,
    )
    filtered = torch.square(kornia.filters.filter2d(img_grey, kernel))
    energy = torch.sum(filtered)  # avoid shadowing the builtin `sum`
    return 100 / energy
class TraditionalLitModule(LightningModule):
    """Non-trainable LightningModule wrapping a classical focus measure."""

    def __init__(
        self,
        method: str = "vol4",
    ):
        """Initialize function for a traditional focus measurement `model`. It cannot be trained.

        Args:
            method (str, optional): The method to use for predicting focus. Defaults to "vol4".
                Possible values are: vol4, mean_laplacian, midfrequency_dct

        Raises:
            ValueError: if ``method`` is not one of the supported names.
        """
        super().__init__()
        # Name -> focus-measure function; forward() delegates to the chosen one.
        functions = {
            "vol4": vol4,
            "mean_laplacian": laplacian,
            "midfrequency_dct": midfrequency_dct,
        }
        try:
            self.function = functions[method]
        except KeyError:
            # Fail at construction instead of with an AttributeError in forward().
            raise ValueError(
                f"Unknown focus method {method!r}; expected one of {sorted(functions)}"
            ) from None

    def forward(self, x):
        """Return the focus score of ``x`` computed by the selected method."""
        return self.function(x)
|