from typing import List, Optional, Tuple
import os
import glob

import numpy as np
import pandas as pd
from PIL import Image
# `scipy.ndimage.filters` is a deprecated alias namespace; the same functions are
# exported from the public `scipy.ndimage` namespace.
from scipy.ndimage import gaussian_filter, median_filter, rank_filter
from torch.utils.data import Dataset
from torchvision import transforms

from utils.constants import Split, Columns, CropsColumns, ProbsColumns
from utils.paths import (
    CROPS_DATASET,
    CROPS_PATH,
    COORDS_PATH,
    IMG_PATH,
    PROBS_DATASET,
    PROBS_PATH,
    HAADF_DATASET,
    PT_DATASET,
)


def load_normalized_image(img_filename: str, bg_window: Tuple[int, int] = (40, 40)) -> np.ndarray:
    """Load an image, subtract a median-filtered background and min-max scale it.

    Args:
        img_filename: path of an image file readable by PIL. The 2-D median filter
            below implies a single-channel (2-D) image.
        bg_window: size of the median filter used to estimate the background.

    Returns:
        float32 array of the image's shape with values in [0, 1]. If the
        background-subtracted image is constant (zero dynamic range) an all-zero
        array is returned instead of dividing by zero (which would yield NaNs).
    """
    # The context manager closes the file handle; astype() copies the pixel data
    # into an independent array before the image is closed.
    with Image.open(img_filename) as img:
        np_img = np.asarray(img).astype(np.float32)
    np_bg = median_filter(np_img, size=bg_window)
    np_clean = np_img - np_bg
    clean_min = np_clean.min()
    value_range = np_clean.max() - clean_min
    if value_range == 0:
        return np.zeros_like(np_clean)
    return (np_clean - clean_min) / value_range


class ImageClassificationDataset(Dataset):
    """Map-style dataset of images loaded from disk with one label per image.

    Subclasses implement ``get_filenames_labels(split)`` returning parallel lists of
    image paths and labels; ``train_dataset`` / ``val_dataset`` / ``test_dataset``
    build an instance for the corresponding split.
    """

    def __init__(self, image_paths, image_labels, include_filename=False):
        """
        Args:
            image_paths: sequence of image file paths.
            image_labels: sequence of labels, parallel to ``image_paths``.
            include_filename: if True, ``__getitem__`` also returns the image basename.
        """
        self.image_paths = image_paths
        self.image_labels = image_labels
        self.include_filename = include_filename
        # No Normalize step: load_image already rescales pixel values to [0, 1].
        self.transform = transforms.Compose([
            transforms.ToTensor(),
        ])

    def get_n_labels(self):
        """Return the number of distinct labels in this dataset."""
        return len(set(self.image_labels))

    def __len__(self):
        return len(self.image_paths)

    @staticmethod
    def load_image(img_filename):
        """Load ``img_filename`` with background removal and [0, 1] scaling."""
        return load_normalized_image(img_filename)

    def __getitem__(self, idx):
        """Return ``(image_tensor, label)`` or ``(image_tensor, label, basename)``."""
        img_path = self.image_paths[idx]
        image = self.transform(self.load_image(img_path))
        label = self.image_labels[idx]
        if self.include_filename:
            return image, label, os.path.basename(img_path)
        return image, label

    @staticmethod
    def get_filenames_labels(split: Split) -> Tuple[List[str], List[int]]:
        """Return ``(filenames, labels)`` for ``split``; implemented by subclasses."""
        raise NotImplementedError

    @classmethod
    def _from_split(cls, split, *split_args, **kwargs):
        """Build a dataset for ``split``; ``split_args`` go to ``get_filenames_labels``."""
        filenames, labels = cls.get_filenames_labels(split, *split_args)
        return cls(filenames, labels, **kwargs)

    @classmethod
    def train_dataset(cls, **kwargs):
        """Dataset for the training split; ``kwargs`` are forwarded to ``__init__``."""
        return cls._from_split(Split.TRAIN, **kwargs)

    @classmethod
    def val_dataset(cls, **kwargs):
        """Dataset for the validation split; ``kwargs`` are forwarded to ``__init__``."""
        return cls._from_split(Split.VAL, **kwargs)

    @classmethod
    def test_dataset(cls, **kwargs):
        """Dataset for the test split; ``kwargs`` are forwarded to ``__init__``."""
        return cls._from_split(Split.TEST, **kwargs)


class HaadfDataset(ImageClassificationDataset):
    """Image dataset whose rows are listed in ``HAADF_DATASET`` (images under ``IMG_PATH``)."""

    @staticmethod
    def get_filenames_labels(split: Split) -> Tuple[List[str], List[int]]:
        df = pd.read_csv(HAADF_DATASET)
        split_df = df[df[Columns.SPLIT] == split]
        filenames = (IMG_PATH + os.sep + split_df[Columns.FILENAME]).to_list()
        labels = split_df[Columns.LABEL].to_list()
        return filenames, labels


class ImageDataset:
    """Iterates image paths listed in a CSV with Filename / Split / Ruler Units columns."""

    FILENAME_COL = "Filename"
    SPLIT_COL = "Split"
    RULER_UNITS = "Ruler Units"

    def __init__(self, dataset_csv: str):
        self.df = pd.read_csv(dataset_csv)

    def iterate_data(self, split: Split):
        """Yield the full path (under ``IMG_PATH``) of every image in ``split``."""
        split_df = self.df[self.df[self.SPLIT_COL] == split]
        for filename in split_df[self.FILENAME_COL]:
            yield os.path.join(IMG_PATH, filename)

    def get_ruler_units_by_img_name(self, name):
        """Return the "Ruler Units" value of the first row whose Filename equals ``name``.

        Raises:
            IndexError: if no row matches ``name``.
        """
        return self.df[self.df[self.FILENAME_COL] == name][self.RULER_UNITS].values[0]


class CoordinatesDataset:
    """Pairs images with optional coordinate CSVs listed in a dataset CSV."""

    FILENAME_COL = "Filename"
    COORDS_COL = "Coords"
    SPLIT_COL = "Split"

    def __init__(self, coord_image_csv: str):
        self.df = pd.read_csv(coord_image_csv)

    def iterate_data(self, split: Split):
        """Yield ``(image_path, coords_path_or_None)`` for every row in ``split``.

        Rows whose Coords cell is not a string (e.g. an empty cell read as NaN by
        pandas) yield ``None`` as the coordinates path.
        """
        split_df = self.df[self.df[self.SPLIT_COL] == split]
        for image_name, coords_name in zip(split_df[self.FILENAME_COL], split_df[self.COORDS_COL]):
            image_filename = os.path.join(IMG_PATH, image_name)
            if isinstance(coords_name, str):
                coords_filename = os.path.join(COORDS_PATH, coords_name)
            else:
                coords_filename = None
            yield image_filename, coords_filename

    @staticmethod
    def load_coordinates(label_filename: str) -> List[Tuple[int, int]]:
        """Read a CSV with ``X`` and ``Y`` columns and return its ``(X, Y)`` pairs."""
        atom_coordinates = pd.read_csv(label_filename)
        return list(zip(atom_coordinates['X'], atom_coordinates['Y']))

    def split_length(self, split: Split):
        """Return the number of rows belonging to ``split``."""
        return len(self.df[self.df[self.SPLIT_COL] == split])


class HaadfCoordinates(CoordinatesDataset):
    """CoordinatesDataset backed by the ``PT_DATASET`` CSV."""

    def __init__(self):
        super().__init__(coord_image_csv=PT_DATASET)


class CropsDataset(ImageClassificationDataset):
    """Crop images listed in ``CROPS_DATASET`` and stored under ``CROPS_PATH``."""

    @staticmethod
    def get_filenames_labels(split: Split):
        df = pd.read_csv(CROPS_DATASET)
        split_df = df[df[CropsColumns.SPLIT] == split]
        filenames = (CROPS_PATH + os.sep + split_df[CropsColumns.FILENAME]).to_list()
        labels = split_df[CropsColumns.LABEL].to_list()
        return filenames, labels


class CropsCustomDataset(ImageClassificationDataset):
    """Crop dataset whose CSV and image directory are supplied by the caller."""

    @staticmethod
    def get_filenames_labels(split: Split, crops_dataset: str, crops_path: str):
        df = pd.read_csv(crops_dataset)
        split_df = df[df[CropsColumns.SPLIT] == split]
        filenames = (crops_path + os.sep + split_df[CropsColumns.FILENAME]).to_list()
        labels = split_df[CropsColumns.LABEL].to_list()
        return filenames, labels

    # The inherited split constructors call get_filenames_labels(split) only, which
    # always failed here because the CSV and directory are required arguments.
    @classmethod
    def train_dataset(cls, crops_dataset: str, crops_path: str, **kwargs):
        return cls._from_split(Split.TRAIN, crops_dataset, crops_path, **kwargs)

    @classmethod
    def val_dataset(cls, crops_dataset: str, crops_path: str, **kwargs):
        return cls._from_split(Split.VAL, crops_dataset, crops_path, **kwargs)

    @classmethod
    def test_dataset(cls, crops_dataset: str, crops_path: str, **kwargs):
        return cls._from_split(Split.TEST, crops_dataset, crops_path, **kwargs)


class ProbsDataset(ImageClassificationDataset):
    """Images listed in ``PROBS_DATASET`` and stored under ``PROBS_PATH``."""

    @staticmethod
    def get_filenames_labels(split: Split):
        df = pd.read_csv(PROBS_DATASET)
        split_df = df[df[ProbsColumns.SPLIT] == split]
        filenames = (PROBS_PATH + os.sep + split_df[ProbsColumns.FILENAME]).to_list()
        labels = split_df[ProbsColumns.LABEL].to_list()
        return filenames, labels


class SlidingCropDataset(Dataset):
    """All fixed-size windows of one image, yielded as ``(crop_tensor, x, y)`` items.

    ``x`` / ``y`` are the window-center indices along axis 0 / axis 1 of the image.
    """

    def __init__(self, tif_filename, include_coords=True, step_size=2, window_size=(21, 21)):
        """
        Args:
            tif_filename: path of the image to crop.
            include_coords: stored for callers; NOTE(review): __getitem__ currently
                always returns coordinates regardless of this flag.
            step_size: stride (in pixels) between consecutive window origins.
            window_size: (rows, cols) size of each crop.
        """
        self.filename = tif_filename
        self.include_coords = include_coords
        # NOTE(review): unlike ImageClassificationDataset, crops here are shifted to
        # [-1, 1] by Normalize(0.5, 0.5) — confirm this matches the model's training input.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5]),
        ])
        self.n_labels = 2
        self.step_size = step_size
        self.window_size = tuple(window_size)
        self.loaded_crops = []
        self.loaded_coords = []
        self.load_crops()

    def sliding_window(self, image):
        """Yield ``(center_x, center_y, crop)`` for every window position.

        Window origins run from 0 up to and including ``shape - window_size`` in
        steps of ``step_size``, so a window ending exactly on the border is kept.
        """
        win_x, win_y = self.window_size
        for x in range(0, image.shape[0] - win_x + 1, self.step_size):
            for y in range(0, image.shape[1] - win_y + 1, self.step_size):
                center_x = x + (win_x - 1) // 2
                center_y = y + (win_y - 1) // 2
                yield center_x, center_y, image[x:x + win_x, y:y + win_y]

    @staticmethod
    def load_image(img_filename):
        """Load ``img_filename`` with background removal and [0, 1] scaling."""
        return load_normalized_image(img_filename)

    def load_crops(self):
        """Load the image and cache every sliding-window crop and its center."""
        img = self.load_image(self.filename)
        for x_center, y_center, img_crop in self.sliding_window(img):
            self.loaded_crops.append(img_crop)
            self.loaded_coords.append((x_center, y_center))

    def get_n_labels(self):
        return self.n_labels

    def __len__(self):
        return len(self.loaded_crops)

    def __getitem__(self, idx):
        x, y = self.loaded_coords[idx]
        crop = self.transform(self.loaded_crops[idx])
        return crop, x, y


def get_image_path_without_coords(split: Optional[str] = None) -> Tuple[List[str], List[int]]:
    """Select images by file-name prefix and label them by their NC tag.

    A prefix is the part of a file name before the first ``_``.

    Args:
        split: which prefixes to use:
            - ``Split.TRAIN``: prefixes that have a file in ``COORDS_PATH``.
            - ``Split.TEST``: prefixes in ``IMG_PATH`` without a coordinates file.
            - ``None``: all prefixes found in ``IMG_PATH``.

    Returns:
        ``(filenames, labels)``. For each prefix, the lexicographically last
        ``*_PtNC*`` match is labeled 1 and the last ``*_NC*`` match is labeled 0.
        Prefixes are processed in sorted order so the output order is deterministic.

    Raises:
        ValueError: for any other ``split`` value.
    """
    coords_prefixes = {name.split('_')[0] for name in os.listdir(COORDS_PATH)}
    all_prefixes = {name.split('_')[0] for name in os.listdir(IMG_PATH)}

    if split == Split.TRAIN:
        selected_prefixes = coords_prefixes
    elif split == Split.TEST:
        selected_prefixes = all_prefixes - coords_prefixes
    elif split is None:
        selected_prefixes = all_prefixes
    else:
        raise ValueError(f"Unsupported split: {split!r}")

    tif_filenames_list = []
    labels_list = []
    search_dir = glob.escape(IMG_PATH)
    for prefix in sorted(selected_prefixes):
        pattern = os.path.join(search_dir, f'{glob.escape(prefix)}_HAADF*NC*')
        filename_matches = glob.glob(pattern)
        # Test the tags on the basename only so directory names cannot match.
        pos_filenames = [f for f in filename_matches if '_PtNC' in os.path.basename(f)]
        neg_filenames = [f for f in filename_matches if '_NC' in os.path.basename(f)]
        if pos_filenames:
            tif_filenames_list.append(max(pos_filenames))
            labels_list.append(1)
        if neg_filenames:
            tif_filenames_list.append(max(neg_filenames))
            labels_list.append(0)
    return tif_filenames_list, labels_list


if __name__ == "__main__":
    # The function returns (filenames, labels); take the first filename.
    tif_filenames, _labels = get_image_path_without_coords()
    filename = tif_filenames[0]
    dataset = SlidingCropDataset(filename)