from typing import List, Optional, Tuple

import os
import glob

import numpy as np
import pandas as pd
from PIL import Image
# scipy.ndimage.filters is deprecated; import median_filter from scipy.ndimage directly.
from scipy.ndimage import median_filter
from torch.utils.data import Dataset
from torchvision import transforms

from utils.constants import Split, Columns, CropsColumns, ProbsColumns
from utils.paths import CROPS_DATASET, CROPS_PATH, COORDS_PATH, IMG_PATH, PROBS_DATASET, PROBS_PATH, HAADF_DATASET, PT_DATASET

class ImageClassificationDataset(Dataset):
    """Base dataset: background-subtracted grayscale images with integer labels."""

    def __init__(self, image_paths, image_labels, include_filename=False):
        self.image_paths = image_paths
        self.image_labels = image_labels
        self.include_filename = include_filename
        self.transform = transforms.Compose([
            transforms.ToTensor()
            # transforms.Normalize(mean=[0.5], std=[0.5])
        ])

    def get_n_labels(self):
        return len(set(self.image_labels))

    def __len__(self):
        return len(self.image_paths)

    @staticmethod
    def load_image(img_filename):
        # Remove the slowly varying background with a large median filter,
        # then rescale the residual to [0, 1].
        img = Image.open(img_filename)
        np_img = np.asarray(img).astype(np.float32)
        np_bg = median_filter(np_img, size=(40, 40))
        np_clean = np_img - np_bg
        np_normed = (np_clean - np_clean.min()) / (np_clean.max() - np_clean.min())
        return np_normed

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = self.load_image(img_path)
        image = self.transform(image)
        label = self.image_labels[idx]
        if self.include_filename:
            return image, label, os.path.basename(img_path)
        else:
            return image, label

    @staticmethod
    def get_filenames_labels(split: Split) -> Tuple[List[str], List[int]]:
        raise NotImplementedError

    @classmethod
    def train_dataset(cls, **kwargs):
        filenames, labels = cls.get_filenames_labels(Split.TRAIN)
        return cls(filenames, labels, **kwargs)

    @classmethod
    def val_dataset(cls, **kwargs):
        filenames, labels = cls.get_filenames_labels(Split.VAL)
        return cls(filenames, labels, **kwargs)

    @classmethod
    def test_dataset(cls, **kwargs):
        filenames, labels = cls.get_filenames_labels(Split.TEST)
        return cls(filenames, labels, **kwargs)

class HaadfDataset(ImageClassificationDataset):
    @staticmethod
    def get_filenames_labels(split: Split) -> Tuple[List[str], List[int]]:
        df = pd.read_csv(HAADF_DATASET)
        split_df = df[df[Columns.SPLIT] == split]
        filenames = (IMG_PATH + os.sep + split_df[Columns.FILENAME]).to_list()
        labels = (split_df[Columns.LABEL]).to_list()
        return filenames, labels

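# Hedged usage sketch (not part of the original pipeline): the classmethod
# factories above pair naturally with a PyTorch DataLoader. The batch size,
# shuffling, and helper name below are illustrative assumptions.
def _example_haadf_loaders(batch_size: int = 32):
    from torch.utils.data import DataLoader

    train_loader = DataLoader(HaadfDataset.train_dataset(), batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(HaadfDataset.val_dataset(), batch_size=batch_size, shuffle=False)
    return train_loader, val_loader
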
class ImageDataset:
    FILENAME_COL = "Filename"
    SPLIT_COL = "Split"
    RULER_UNITS = "Ruler Units"

    def __init__(self, dataset_csv: str):
        self.df = pd.read_csv(dataset_csv)

    def iterate_data(self, split: Split):
        df = self.df[self.df[self.SPLIT_COL] == split]
        for idx, row in df.iterrows():
            image_filename = os.path.join(IMG_PATH, row[self.FILENAME_COL])
            yield image_filename

    def get_ruler_units_by_img_name(self, name):
        return self.df[self.df[self.FILENAME_COL] == name][self.RULER_UNITS].values[0]

class CoordinatesDataset:
    FILENAME_COL = "Filename"
    COORDS_COL = "Coords"
    SPLIT_COL = "Split"

    def __init__(self, coord_image_csv: str):
        self.df = pd.read_csv(coord_image_csv)

    def iterate_data(self, split: Split):
        df = self.df[self.df[self.SPLIT_COL] == split]
        for idx, row in df.iterrows():
            image_filename = os.path.join(IMG_PATH, row[self.FILENAME_COL])
            # Rows without an annotation file have NaN in the coords column.
            if isinstance(row[self.COORDS_COL], str):
                coords_filename = os.path.join(COORDS_PATH, row[self.COORDS_COL])
            else:
                coords_filename = None
            yield image_filename, coords_filename

    @staticmethod
    def load_coordinates(label_filename: str) -> List[Tuple[int, int]]:
        atom_coordinates = pd.read_csv(label_filename)
        return list(zip(atom_coordinates['X'], atom_coordinates['Y']))

    def split_length(self, split: Split):
        df = self.df[self.df[self.SPLIT_COL] == split]
        return len(df)

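# Hedged usage sketch: iterate one split of a coordinate dataset and load the
# atom coordinates where an annotation file exists. The helper name and the
# empty-list fallback are illustrative assumptions.
def _example_iterate_coordinates(dataset: CoordinatesDataset, split: Split):
    for image_filename, coords_filename in dataset.iterate_data(split):
        coords = CoordinatesDataset.load_coordinates(coords_filename) if coords_filename else []
        yield image_filename, coords
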
class HaadfCoordinates(CoordinatesDataset):
    def __init__(self):
        super().__init__(coord_image_csv=PT_DATASET)

class CropsDataset(ImageClassificationDataset):
    @staticmethod
    def get_filenames_labels(split: Split) -> Tuple[List[str], List[int]]:
        df = pd.read_csv(CROPS_DATASET)
        split_df = df[df[CropsColumns.SPLIT] == split]
        filenames = (CROPS_PATH + os.sep + split_df[CropsColumns.FILENAME]).to_list()
        labels = (split_df[CropsColumns.LABEL]).to_list()
        return filenames, labels

class CropsCustomDataset(ImageClassificationDataset):
    @staticmethod
    def get_filenames_labels(split: Split, crops_dataset: str, crops_path: str) -> Tuple[List[str], List[int]]:
        df = pd.read_csv(crops_dataset)
        split_df = df[df[CropsColumns.SPLIT] == split]
        filenames = (crops_path + os.sep + split_df[CropsColumns.FILENAME]).to_list()
        labels = (split_df[CropsColumns.LABEL]).to_list()
        return filenames, labels

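# Hedged usage sketch: because CropsCustomDataset.get_filenames_labels takes an
# explicit CSV and crop directory, the inherited train_dataset()/val_dataset()
# factories cannot supply those arguments; constructing it directly looks like
# this. The helper name and paths are hypothetical placeholders.
def _example_custom_crops(split: Split, csv_path: str, crops_dir: str) -> CropsCustomDataset:
    filenames, labels = CropsCustomDataset.get_filenames_labels(split, csv_path, crops_dir)
    return CropsCustomDataset(filenames, labels)
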
class ProbsDataset(ImageClassificationDataset):
    @staticmethod
    def get_filenames_labels(split: Split) -> Tuple[List[str], List[int]]:
        df = pd.read_csv(PROBS_DATASET)
        split_df = df[df[ProbsColumns.SPLIT] == split]
        filenames = (PROBS_PATH + os.sep + split_df[ProbsColumns.FILENAME]).to_list()
        labels = (split_df[ProbsColumns.LABEL]).to_list()
        return filenames, labels

class SlidingCropDataset(Dataset):
    """Yields fixed-size crops taken on a regular grid over a single image."""

    def __init__(self, tif_filename, include_coords=True):
        self.filename = tif_filename
        self.include_coords = include_coords
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5])
        ])
        self.n_labels = 2
        self.step_size = 2
        self.window_size = (21, 21)
        self.loaded_crops = []
        self.loaded_coords = []
        self.load_crops()

    def sliding_window(self, image):
        # Slide a window across the image and yield each crop together with
        # the coordinates of its center pixel.
        for x in range(0, image.shape[0] - self.window_size[0], self.step_size):
            for y in range(0, image.shape[1] - self.window_size[1], self.step_size):
                center_x = x + ((self.window_size[0] - 1) // 2)
                center_y = y + ((self.window_size[1] - 1) // 2)
                yield center_x, center_y, image[x:x + self.window_size[0], y:y + self.window_size[1]]

    @staticmethod
    def load_image(img_filename):
        # Same background subtraction and normalization as ImageClassificationDataset.
        img = Image.open(img_filename)
        np_img = np.asarray(img).astype(np.float32)
        np_bg = median_filter(np_img, size=(40, 40))
        np_clean = np_img - np_bg
        np_normed = (np_clean - np_clean.min()) / (np_clean.max() - np_clean.min())
        return np_normed

    def load_crops(self):
        img = self.load_image(self.filename)
        for x_center, y_center, img_crop in self.sliding_window(img):
            self.loaded_crops.append(img_crop)
            self.loaded_coords.append((x_center, y_center))

    def get_n_labels(self):
        return self.n_labels

    def __len__(self):
        return len(self.loaded_crops)

    def __getitem__(self, idx):
        crop = self.loaded_crops[idx]
        x, y = self.loaded_coords[idx]
        crop = self.transform(crop)
        return crop, x, y

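# Hedged usage sketch: run a trained classifier over the sliding-window crops
# to collect a per-center score. `model` (a 2-class classifier), the batch
# size, and the helper name are assumptions; only the dataset interface above
# comes from this module.
def _example_score_crops(model, tif_filename, batch_size: int = 256):
    import torch
    from torch.utils.data import DataLoader

    dataset = SlidingCropDataset(tif_filename)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    scores = []
    model.eval()
    with torch.no_grad():
        for crops, xs, ys in loader:
            probs = torch.softmax(model(crops), dim=1)[:, 1]
            scores.extend(zip(xs.tolist(), ys.tolist(), probs.tolist()))
    return scores
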
def get_image_path_without_coords(split: Optional[Split] = None):
    # Prefixes of images that already have a coordinate annotation file.
    coords_prefix_set = set()
    for coords_name in os.listdir(COORDS_PATH):
        coord_prefix = coords_name.split('_')[0]
        coords_prefix_set.add(coord_prefix)
    # Prefixes of all available images.
    all_prefixes_set = set()
    for tif_name in os.listdir(IMG_PATH):
        coord_prefix = tif_name.split('_')[0]
        all_prefixes_set.add(coord_prefix)
    if split == Split.TRAIN:
        missing_prefixes = coords_prefix_set
    elif split == Split.TEST:
        missing_prefixes = all_prefixes_set - coords_prefix_set
    elif split is None:
        missing_prefixes = all_prefixes_set
    else:
        raise ValueError(f"Unsupported split: {split}")
    tif_filenames_list = []
    labels_list = []
    for prefix in missing_prefixes:
        filename_matches = glob.glob(os.path.join(IMG_PATH, f'{prefix}_HAADF*NC*'))
        if len(filename_matches) == 0:
            continue
        # Filenames containing '_PtNC' are labeled 1, those containing '_NC' are labeled 0.
        pos_filenames = [filename for filename in filename_matches if '_PtNC' in filename]
        neg_filenames = [filename for filename in filename_matches if '_NC' in filename]
        if len(pos_filenames) > 0:
            pos_filename = sorted(pos_filenames)[-1]
            tif_filenames_list.append(pos_filename)
            labels_list.append(1)
        if len(neg_filenames) > 0:
            neg_filename = sorted(neg_filenames)[-1]
            tif_filenames_list.append(neg_filename)
            labels_list.append(0)
    return tif_filenames_list, labels_list

if __name__ == "__main__":
    # get_image_path_without_coords returns (filenames, labels); unpack both.
    filenames_list, labels_list = get_image_path_without_coords()
    filename = filenames_list[0]
    dataset = SlidingCropDataset(filename)