File size: 9,156 Bytes
b2ffc9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
import glob
import os
from typing import List, Optional, Tuple

import numpy as np
import pandas as pd
from PIL import Image
from scipy.ndimage.filters import gaussian_filter, median_filter, rank_filter
from torch.utils.data import Dataset
from torchvision import transforms

from utils.constants import Split, Columns, CropsColumns, ProbsColumns
from utils.paths import CROPS_DATASET, CROPS_PATH, COORDS_PATH, IMG_PATH, PROBS_DATASET, PROBS_PATH, HAADF_DATASET, PT_DATASET


class ImageClassificationDataset(Dataset):
    """Base dataset serving background-corrected images together with labels.

    Subclasses provide the file list by overriding ``get_filenames_labels``;
    the ``train_dataset`` / ``val_dataset`` / ``test_dataset`` factories then
    build an instance for the matching split.
    """

    def __init__(self, image_paths, image_labels, include_filename=False):
        """Store the sample list and build the tensor transform.

        Args:
            image_paths: Indexable sequence of image file paths.
            image_labels: Indexable sequence of labels, aligned with
                ``image_paths``.
            include_filename: When true, ``__getitem__`` additionally returns
                the base name of the image file.
        """
        self.image_paths = image_paths
        self.image_labels = image_labels
        self.include_filename = include_filename
        self.transform = transforms.Compose([
            transforms.ToTensor()
            # transforms.Normalize(mean=[0.5], std=[0.5])
        ])

    def get_n_labels(self):
        """Return the number of distinct values in ``image_labels``."""
        return len(set(self.image_labels))

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.image_paths)

    @staticmethod
    def _normalize(np_img):
        """Subtract a 40x40 median background and rescale the result to [0, 1].

        Args:
            np_img: float array; the two-element filter size means a 2-D
                image is expected.

        Returns:
            Array of the same shape as ``np_img``. If the background-subtracted
            image is constant, an all-zero array is returned instead of the
            NaNs a 0/0 division would produce.
        """
        np_bg = median_filter(np_img, size=(40, 40))
        np_clean = np_img - np_bg
        lowest = np_clean.min()
        value_range = np_clean.max() - lowest
        if value_range == 0:
            # Flat image: min-max scaling is undefined, so return zeros.
            return np.zeros_like(np_clean)
        return (np_clean - lowest) / value_range

    @staticmethod
    def load_image(img_filename):
        """Read ``img_filename`` and return it as a normalized float32 array.

        The pixel data is copied into a float32 array before the file is
        closed, then passed through ``_normalize``.
        """
        # The context manager releases the file handle as soon as the pixels
        # have been copied out, instead of waiting for garbage collection.
        with Image.open(img_filename) as img:
            np_img = np.asarray(img).astype(np.float32)
        return ImageClassificationDataset._normalize(np_img)

    def __getitem__(self, idx):
        """Return ``(image, label)`` or ``(image, label, basename)`` for ``idx``.

        The image is loaded from disk on every access and passed through
        ``self.transform``.
        """
        img_path = self.image_paths[idx]
        image = self.load_image(img_path)
        image = self.transform(image)
        label = self.image_labels[idx]

        if self.include_filename:
            return image, label, os.path.basename(img_path)
        return image, label

    @staticmethod
    def get_filenames_labels(split: Split) -> Tuple[List[str], List[int]]:
        """Return ``(filenames, labels)`` for ``split``; subclasses must override.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    @classmethod
    def train_dataset(cls, **kwargs):
        """Build a dataset from the ``Split.TRAIN`` file list.

        ``kwargs`` are forwarded to the constructor.
        """
        filenames, labels = cls.get_filenames_labels(Split.TRAIN)
        return cls(filenames, labels, **kwargs)

    @classmethod
    def val_dataset(cls, **kwargs):
        """Build a dataset from the ``Split.VAL`` file list.

        ``kwargs`` are forwarded to the constructor.
        """
        filenames, labels = cls.get_filenames_labels(Split.VAL)
        return cls(filenames, labels, **kwargs)

    @classmethod
    def test_dataset(cls, **kwargs):
        """Build a dataset from the ``Split.TEST`` file list.

        ``kwargs`` are forwarded to the constructor.
        """
        filenames, labels = cls.get_filenames_labels(Split.TEST)
        return cls(filenames, labels, **kwargs)


class HaadfDataset(ImageClassificationDataset):
    """Image classification dataset listed in the ``HAADF_DATASET`` CSV."""

    @staticmethod
    def get_filenames_labels(split: Split) -> Tuple[List[str], List[int]]:
        """Return image paths and labels of the CSV rows belonging to ``split``.

        Paths are built as ``IMG_PATH + os.sep + <filename column>``.
        """
        table = pd.read_csv(HAADF_DATASET)
        in_split = table[Columns.SPLIT] == split
        rows = table[in_split]
        path_prefix = IMG_PATH + os.sep
        image_files = (path_prefix + rows[Columns.FILENAME]).to_list()
        image_labels = rows[Columns.LABEL].to_list()
        return image_files, image_labels


class ImageDataset:
    """Table of images read from a CSV, with split filtering and unit lookup."""

    FILENAME_COL = "Filename"
    SPLIT_COL = "Split"
    RULER_UNITS = "Ruler Units"

    def __init__(self, dataset_csv: str):
        """Read the CSV at ``dataset_csv`` into ``self.df``."""
        self.df = pd.read_csv(dataset_csv)

    def iterate_data(self, split: Split):
        """Yield the full path (under ``IMG_PATH``) of every image in ``split``."""
        df = self.df[self.df[self.SPLIT_COL] == split]
        # Only the filename column is needed, so iterate it directly rather
        # than materializing a Series per row with ``iterrows``.
        for image_name in df[self.FILENAME_COL]:
            yield os.path.join(IMG_PATH, image_name)

    def get_ruler_units_by_img_name(self, name):
        """Return the ``Ruler Units`` value of the first row whose filename is ``name``.

        Raises:
            IndexError: If no row has that filename.
        """
        matches = self.df[self.df[self.FILENAME_COL] == name][self.RULER_UNITS].values
        if len(matches) == 0:
            # Same exception type as the bare ``values[0]`` lookup, but with a
            # message that identifies the missing image.
            raise IndexError(f"no row with {self.FILENAME_COL}={name!r} in the image dataset")
        return matches[0]
        


class CoordinatesDataset:
    """Pairs every image listed in a CSV with its optional coordinates file."""

    FILENAME_COL = "Filename"
    COORDS_COL = "Coords"
    SPLIT_COL = "Split"

    def __init__(self, coord_image_csv: str):
        """Read the CSV at ``coord_image_csv`` into ``self.df``."""
        self.df = pd.read_csv(coord_image_csv)

    def iterate_data(self, split: Split):
        """Yield ``(image_path, coords_path)`` for each row of ``split``.

        ``coords_path`` is ``None`` whenever the row's coordinates cell is not
        a string.
        """
        in_split = self.df[self.SPLIT_COL] == split
        for _, record in self.df[in_split].iterrows():
            image_path = os.path.join(IMG_PATH, record[self.FILENAME_COL])
            coords_entry = record[self.COORDS_COL]
            coords_path = (
                os.path.join(COORDS_PATH, coords_entry)
                if isinstance(coords_entry, str)
                else None
            )
            yield image_path, coords_path

    @staticmethod
    def load_coordinates(label_filename: str) -> List[Tuple[int, int]]:
        """Read a coordinates CSV and return its ``X``/``Y`` columns as pairs."""
        table = pd.read_csv(label_filename)
        xs, ys = table['X'], table['Y']
        return [(x, y) for x, y in zip(xs, ys)]

    def split_length(self, split: Split):
        """Return how many rows of the table belong to ``split``."""
        in_split = self.df[self.SPLIT_COL] == split
        subset = self.df[in_split]
        return len(subset)


class HaadfCoordinates(CoordinatesDataset):
    """``CoordinatesDataset`` bound to the ``PT_DATASET`` CSV."""

    def __init__(self):
        """Load the table from ``PT_DATASET``."""
        super().__init__(PT_DATASET)


class CropsDataset(ImageClassificationDataset):
    """Image classification dataset listed in the ``CROPS_DATASET`` CSV."""

    @staticmethod
    def get_filenames_labels(split: Split):
        """Return crop paths and labels of the CSV rows belonging to ``split``.

        Paths are built as ``CROPS_PATH + os.sep + <filename column>``.
        """
        table = pd.read_csv(CROPS_DATASET)
        in_split = table[CropsColumns.SPLIT] == split
        rows = table[in_split]
        path_prefix = CROPS_PATH + os.sep
        crop_files = (path_prefix + rows[CropsColumns.FILENAME]).to_list()
        crop_labels = rows[CropsColumns.LABEL].to_list()
        return crop_files, crop_labels


class CropsCustomDataset(ImageClassificationDataset):
    """Crop dataset read from a caller-supplied CSV and image directory.

    ``get_filenames_labels`` needs the CSV path and the crops directory in
    addition to the split, so the split factories are overridden here to
    accept and forward them. The inherited factories call
    ``get_filenames_labels(split)`` only and would fail with a ``TypeError``.
    """

    @staticmethod
    def get_filenames_labels(split: Split, crops_dataset: str, crops_path: str):
        """Return crop paths and labels of the rows of ``crops_dataset`` in ``split``.

        Args:
            split: Value compared against the split column.
            crops_dataset: Path of the CSV listing the crops.
            crops_path: Directory prefix prepended (with ``os.sep``) to each
                filename.
        """
        df = pd.read_csv(crops_dataset)
        split_df = df[df[CropsColumns.SPLIT] == split]
        filenames = (crops_path + os.sep + split_df[CropsColumns.FILENAME]).to_list()
        labels = (split_df[CropsColumns.LABEL]).to_list()
        return filenames, labels

    @classmethod
    def _from_split(cls, split, crops_dataset, crops_path, **kwargs):
        """Build an instance for ``split``; ``kwargs`` go to the constructor."""
        filenames, labels = cls.get_filenames_labels(split, crops_dataset, crops_path)
        return cls(filenames, labels, **kwargs)

    @classmethod
    def train_dataset(cls, crops_dataset: str, crops_path: str, **kwargs):
        """Build the ``Split.TRAIN`` dataset from the given CSV and directory."""
        return cls._from_split(Split.TRAIN, crops_dataset, crops_path, **kwargs)

    @classmethod
    def val_dataset(cls, crops_dataset: str, crops_path: str, **kwargs):
        """Build the ``Split.VAL`` dataset from the given CSV and directory."""
        return cls._from_split(Split.VAL, crops_dataset, crops_path, **kwargs)

    @classmethod
    def test_dataset(cls, crops_dataset: str, crops_path: str, **kwargs):
        """Build the ``Split.TEST`` dataset from the given CSV and directory."""
        return cls._from_split(Split.TEST, crops_dataset, crops_path, **kwargs)


class ProbsDataset(ImageClassificationDataset):
    """Image classification dataset listed in the ``PROBS_DATASET`` CSV."""

    @staticmethod
    def get_filenames_labels(split: Split):
        """Return file paths and labels of the CSV rows belonging to ``split``.

        Paths are built as ``PROBS_PATH + os.sep + <filename column>``.
        """
        table = pd.read_csv(PROBS_DATASET)
        in_split = table[ProbsColumns.SPLIT] == split
        rows = table[in_split]
        path_prefix = PROBS_PATH + os.sep
        prob_files = (path_prefix + rows[ProbsColumns.FILENAME]).to_list()
        prob_labels = rows[ProbsColumns.LABEL].to_list()
        return prob_files, prob_labels


class SlidingCropDataset(Dataset):
    """Dataset of fixed-size crops taken by sliding a window over one image.

    All crops are extracted eagerly in the constructor and kept in memory
    together with the coordinates of their centers.
    """

    def __init__(self, tif_filename, include_coords=True):
        """Load ``tif_filename`` and pre-compute every crop.

        Args:
            tif_filename: Path of the image to crop.
            include_coords: Stored on the instance.
                NOTE(review): ``__getitem__`` currently returns the
                coordinates regardless of this flag.
        """
        self.filename = tif_filename
        self.include_coords = include_coords
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5])
        ])

        self.n_labels = 2
        self.step_size = 2
        self.window_size = (21, 21)
        self.loaded_crops = []
        self.loaded_coords = []
        self.load_crops()

    def sliding_window(self, image):
        """Yield ``(center_x, center_y, crop)`` for each window position.

        ``x`` runs along axis 0 and ``y`` along axis 1 of ``image``, both in
        steps of ``self.step_size``. The ``+ 1`` in the range bounds includes
        the last position at which the window still fits, so an image exactly
        the size of the window yields one crop. An image smaller than the
        window yields nothing.
        """
        win_x, win_y = self.window_size
        half_x = (win_x - 1) // 2
        half_y = (win_y - 1) // 2
        for x in range(0, image.shape[0] - win_x + 1, self.step_size):
            for y in range(0, image.shape[1] - win_y + 1, self.step_size):
                yield x + half_x, y + half_y, image[x:x + win_x, y:y + win_y]

    @staticmethod
    def _normalize(np_img):
        """Subtract a 40x40 median background and rescale the result to [0, 1].

        Args:
            np_img: float array; the two-element filter size means a 2-D
                image is expected.

        Returns:
            Array of the same shape as ``np_img``. If the background-subtracted
            image is constant, an all-zero array is returned instead of the
            NaNs a 0/0 division would produce.
        """
        np_bg = median_filter(np_img, size=(40, 40))
        np_clean = np_img - np_bg
        lowest = np_clean.min()
        value_range = np_clean.max() - lowest
        if value_range == 0:
            # Flat image: min-max scaling is undefined, so return zeros.
            return np.zeros_like(np_clean)
        return (np_clean - lowest) / value_range

    @staticmethod
    def load_image(img_filename):
        """Read ``img_filename`` and return it as a normalized float32 array.

        The pixel data is copied into a float32 array before the file is
        closed, then passed through ``_normalize``.
        """
        # The context manager releases the file handle as soon as the pixels
        # have been copied out, instead of waiting for garbage collection.
        with Image.open(img_filename) as img:
            np_img = np.asarray(img).astype(np.float32)
        return SlidingCropDataset._normalize(np_img)

    def load_crops(self):
        """(Re)build ``loaded_crops`` and ``loaded_coords`` from ``self.filename``.

        Fresh lists are assigned, so calling this again does not duplicate
        the crops already loaded.
        """
        img = self.load_image(self.filename)
        crops = []
        coords = []
        for x_center, y_center, img_crop in self.sliding_window(img):
            crops.append(img_crop)
            coords.append((x_center, y_center))
        self.loaded_crops = crops
        self.loaded_coords = coords

    def get_n_labels(self):
        """Return the fixed number of labels (``self.n_labels``)."""
        return self.n_labels

    def __len__(self):
        """Return the number of pre-computed crops."""
        return len(self.loaded_crops)

    def __getitem__(self, idx):
        """Return ``(crop, x, y)``: the transformed crop and its center."""
        crop = self.loaded_crops[idx]
        x, y = self.loaded_coords[idx]
        crop = self.transform(crop)

        return crop, x, y


def get_image_path_without_coords(split: Optional[str] = None):
    """Collect one PtNC and one NC HAADF image per sample prefix, with labels.

    A prefix is the part of a file name before its first underscore. The set
    of prefixes considered depends on ``split``:

    * ``Split.TRAIN``: prefixes that have a file in ``COORDS_PATH``.
    * ``Split.TEST``: prefixes present in ``IMG_PATH`` but not in
      ``COORDS_PATH``.
    * ``None``: every prefix present in ``IMG_PATH``.

    For each prefix, the files in ``IMG_PATH`` matching
    ``<prefix>_HAADF*NC*`` are split into those whose name contains ``_PtNC``
    (label 1) and those whose name contains ``_NC`` (label 0). From each
    non-empty group the lexicographically last path is kept.

    Args:
        split: ``Split.TRAIN``, ``Split.TEST`` or ``None``.

    Returns:
        ``(tif_filenames_list, labels_list)``, two parallel lists. Prefixes
        are processed in sorted order so the result does not vary between
        runs.

    Raises:
        ValueError: If ``split`` is any other value.
    """
    coords_prefix_set = {name.split('_')[0] for name in os.listdir(COORDS_PATH)}
    all_prefixes_set = {name.split('_')[0] for name in os.listdir(IMG_PATH)}

    if split == Split.TRAIN:
        selected_prefixes = coords_prefix_set
    elif split == Split.TEST:
        selected_prefixes = all_prefixes_set - coords_prefix_set
    elif split is None:
        selected_prefixes = all_prefixes_set
    else:
        raise ValueError(f"unsupported split: {split!r}")

    tif_filenames_list = []
    labels_list = []
    # Sets iterate in hash order, which differs between interpreter runs;
    # sorting makes the output order stable.
    for prefix in sorted(selected_prefixes):
        # Escape the prefix so glob metacharacters in a file name are taken
        # literally.
        pattern = os.path.join(IMG_PATH, f'{glob.escape(prefix)}_HAADF*NC*')
        filename_matches = glob.glob(pattern)
        if not filename_matches:
            continue
        # Test the base name only, so a directory component of IMG_PATH that
        # contains '_NC' or '_PtNC' cannot affect the classification.
        pos_filenames = [f for f in filename_matches if '_PtNC' in os.path.basename(f)]
        neg_filenames = [f for f in filename_matches if '_NC' in os.path.basename(f)]

        if pos_filenames:
            tif_filenames_list.append(max(pos_filenames))
            labels_list.append(1)
        if neg_filenames:
            tif_filenames_list.append(max(neg_filenames))
            labels_list.append(0)

    return tif_filenames_list, labels_list


if __name__ == "__main__":
    # Manual smoke test: build a sliding-crop dataset from the first image found.
    # get_image_path_without_coords returns (filenames, labels); unpack it so
    # that a single path, not the whole filename list, reaches the dataset.
    filenames_list, _labels_list = get_image_path_without_coords()
    if not filenames_list:
        raise SystemExit("no matching HAADF images found")
    filename = filenames_list[0]
    dataset = SlidingCropDataset(filename)