File size: 3,778 Bytes
4b3d085
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import albumentations as A
import torch
import torch.nn as nn
import torch.nn.functional as F

from numpy.typing import NDArray
from transformers import PreTrainedModel
from timm import create_model
from typing import Optional

from .configuration import MammoCropConfig

# Optional dependency guard: pydicom is only required for reading DICOM files
# (see MammoCropModel.load_image_from_dicom); the rest of the module works
# without it. The flag is checked at call time instead of failing at import.
_PYDICOM_AVAILABLE = False
try:
    from pydicom import dcmread
    from pydicom.pixels import apply_voi_lut

    _PYDICOM_AVAILABLE = True
except ModuleNotFoundError:
    # pydicom not installed; DICOM loading will return None with a message.
    pass


class GeM(nn.Module):
    """Generalized-mean (GeM) pooling over spatial (and optionally temporal) dims.

    Computes ``avg_pool(x.clamp(min=eps) ** p) ** (1/p)`` with a learnable
    exponent ``p``. With ``p == 1`` this is average pooling; as ``p`` grows it
    approaches max pooling.
    """

    def __init__(
        self, p: int = 3, eps: float = 1e-6, dim: int = 2, flatten: bool = True
    ):
        super().__init__()
        # Learnable pooling exponent, stored as a 1-element parameter.
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps
        assert dim in {2, 3}, f"dim must be one of [2, 3], not {dim}"
        self.dim = dim
        # Dispatch to the matching adaptive average pool for 2D or 3D inputs.
        self.func = {
            2: F.adaptive_avg_pool2d,
            3: F.adaptive_avg_pool3d,
        }[self.dim]
        self.flatten = nn.Flatten(1) if flatten else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Expects x of shape (n, c, [t], h, w); clamp avoids 0 ** p issues.
        powered = x.clamp(min=self.eps).pow(self.p)
        pooled = self.func(powered, output_size=1).pow(1.0 / self.p)
        return self.flatten(pooled)


class MammoCropModel(PreTrainedModel):
    """Regresses a crop bounding box (xywh, sigmoid-normalized) for a mammogram.

    The backbone produces a feature map which is GeM-pooled and projected to
    ``config.num_classes`` coordinates. ``forward`` returns coordinates in
    [0, 1], or pixel coordinates when per-sample image shapes are supplied.
    """

    config_class = MammoCropConfig

    def __init__(self, config):
        super().__init__(config)
        # Headless timm backbone: no classifier head, no global pooling,
        # so the raw spatial feature map reaches the GeM pooling layer.
        self.backbone = create_model(
            model_name=config.backbone,
            pretrained=False,
            num_classes=0,
            global_pool="",
            features_only=False,
            in_chans=config.in_chans,
        )
        self.pooling = GeM(p=3, dim=2)
        # NOTE(review): self.dropout is constructed but never applied in
        # forward() — confirm whether it should wrap the pooled features.
        self.dropout = nn.Dropout(p=config.dropout)
        self.linear = nn.Linear(config.feature_dim, config.num_classes)

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Map uint8-range pixel values [0, 255] to [-1, 1]."""
        mini, maxi = 0.0, 255.0
        x = (x - mini) / (maxi - mini)
        x = (x - 0.5) * 2.0
        return x

    @staticmethod
    def load_image_from_dicom(path: str) -> Optional[NDArray]:
        """Read a DICOM file and return a uint8 image array, or None.

        Applies the VOI LUT, inverts MONOCHROME1 images, and min-max scales
        the result to [0, 255]. Returns None when pydicom is unavailable.
        """
        if not _PYDICOM_AVAILABLE:
            print("`pydicom` is not installed, returning None ...")
            return None
        dicom = dcmread(path)
        arr = apply_voi_lut(dicom.pixel_array, dicom)
        if dicom.PhotometricInterpretation == "MONOCHROME1":
            # MONOCHROME1 stores inverted intensities (min value = white);
            # flip so higher values mean brighter.
            arr = arr.max() - arr

        arr = arr - arr.min()
        # Guard against a constant image: arr.max() is 0 after the shift,
        # and dividing would produce NaN/ZeroDivisionError. Leave it all-zero.
        max_val = arr.max()
        if max_val > 0:
            arr = arr / max_val
        arr = (arr * 255).astype("uint8")
        return arr

    @staticmethod
    def preprocess(x: NDArray) -> NDArray:
        """Resize an image array to the fixed 256x256 model input size."""
        # NOTE(review): 256 is hard-coded; presumably matches training — confirm.
        return A.Resize(256, 256, p=1)(image=x)["image"]

    def forward(
        self, x: torch.Tensor, img_shape: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Predict crop coordinates for a batch of images.

        Args:
            x: image batch; pixel values expected in [0, 255].
            img_shape: optional (batch, 2) tensor of (height, width). When
                given, coordinates are rescaled to pixels and truncated to int;
                otherwise normalized [0, 1] coordinates are returned.

        Returns:
            (batch, num_classes) tensor of xywh coordinates.

        Raises:
            ValueError: if ``img_shape`` batch size does not match ``x``.
        """
        if img_shape is not None and x.size(0) != img_shape.size(0):
            # Explicit raise instead of assert: asserts vanish under `python -O`.
            raise ValueError(
                f"x.size(0) [{x.size(0)}] must equal img_shape.size(0) [{img_shape.size(0)}]"
            )

        x = self.normalize(x)
        features = self.pooling(self.backbone(x))
        coords = self.linear(features).sigmoid()

        if img_shape is None:
            return coords

        # img_shape columns are (height, width); xywh scales by (w, h, w, h).
        scale = img_shape[:, [1, 0, 1, 0]].to(coords)
        return (coords * scale).int()