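"""Dataset utilities for HAADF image classification and atom localisation.

CSV-indexed classification datasets (full images, crops, 'Probs' entries),
image/coordinate iterators, and a sliding-window crop dataset over one image.
"""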
from typing import List, Optional, Tuple
import os
import glob
import numpy as np
import pandas as pd
from PIL import Image
from scipy.ndimage import median_filter
from torch.utils.data import Dataset
from torchvision import transforms
from utils.constants import Split, Columns, CropsColumns, ProbsColumns
from utils.paths import CROPS_DATASET, CROPS_PATH, COORDS_PATH, IMG_PATH, PROBS_DATASET, PROBS_PATH, HAADF_DATASET, PT_DATASET
class ImageClassificationDataset(Dataset):
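    """Base dataset for labelled grayscale images.

    Images are background-subtracted with a median filter, min-max normalised
    and converted to tensors; items are (image, label) or (image, label,
    filename) when ``include_filename`` is set. Subclasses implement
    ``get_filenames_labels`` to map a Split to file paths and integer labels.
    """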
def __init__(self, image_paths, image_labels, include_filename=False):
self.image_paths = image_paths
self.image_labels = image_labels
self.include_filename = include_filename
self.transform = transforms.Compose([
transforms.ToTensor()
# transforms.Normalize(mean=[0.5], std=[0.5])
])
def get_n_labels(self):
return len(set(self.image_labels))
def __len__(self):
return len(self.image_paths)
@staticmethod
def load_image(img_filename):
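        """Load an image, subtract a median-filtered (40x40) background
        estimate and min-max rescale the result to [0, 1].

        Assumes the background-subtracted image is not constant; a constant
        image would make the normalisation divide by zero.
        """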
img = Image.open(img_filename)
np_img = np.asarray(img).astype(np.float32)
np_bg = median_filter(np_img, size=(40, 40))
np_clean = np_img - np_bg
np_normed = (np_clean - np_clean.min()) / (np_clean.max() - np_clean.min())
return np_normed
def __getitem__(self, idx):
img_path = self.image_paths[idx]
image = self.load_image(img_path)
image = self.transform(image)
label = self.image_labels[idx]
if self.include_filename:
return image, label, os.path.basename(img_path)
else:
return image, label
@staticmethod
def get_filenames_labels(split: Split) -> Tuple[List[str], List[int]]:
raise NotImplementedError
@classmethod
def train_dataset(cls, **kwargs):
filenames, labels = cls.get_filenames_labels(Split.TRAIN)
return cls(filenames, labels, **kwargs)
@classmethod
def val_dataset(cls, **kwargs):
filenames, labels = cls.get_filenames_labels(Split.VAL)
return cls(filenames, labels, **kwargs)
@classmethod
def test_dataset(cls, **kwargs):
filenames, labels = cls.get_filenames_labels(Split.TEST)
return cls(filenames, labels, **kwargs)
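
# Illustrative sketch (not part of the original code): any concrete subclass,
# e.g. HaadfDataset defined below, can be wrapped in a torch DataLoader.
# The helper name and batch size are arbitrary assumptions.
def _example_train_loader(dataset_cls, batch_size=32):
    from torch.utils.data import DataLoader
    train_ds = dataset_cls.train_dataset()
    return DataLoader(train_ds, batch_size=batch_size, shuffle=True)
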
class HaadfDataset(ImageClassificationDataset):
@staticmethod
def get_filenames_labels(split: Split) -> Tuple[List[str], List[int]]:
df = pd.read_csv(HAADF_DATASET)
split_df = df[df[Columns.SPLIT] == split]
filenames = (IMG_PATH + os.sep + split_df[Columns.FILENAME]).to_list()
labels = (split_df[Columns.LABEL]).to_list()
return filenames, labels
class ImageDataset:
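    """CSV-backed index of full images: iterates image paths per split and
    looks up the 'Ruler Units' metadata column by filename."""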
FILENAME_COL = "Filename"
SPLIT_COL = "Split"
RULER_UNITS = "Ruler Units"
def __init__(self, dataset_csv: str):
self.df = pd.read_csv(dataset_csv)
def iterate_data(self, split: Split):
df = self.df[self.df[self.SPLIT_COL] == split]
for idx, row in df.iterrows():
image_filename = os.path.join(IMG_PATH, row[self.FILENAME_COL])
yield image_filename
    def get_ruler_units_by_img_name(self, name):
        return self.df[self.df[self.FILENAME_COL] == name][self.RULER_UNITS].values[0]
class CoordinatesDataset:
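    """CSV-backed pairing of images with optional atom-coordinate CSV files.

    Rows with no 'Coords' entry yield ``coords_filename=None``.
    """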
FILENAME_COL = "Filename"
COORDS_COL = "Coords"
SPLIT_COL = "Split"
def __init__(self, coord_image_csv: str):
self.df = pd.read_csv(coord_image_csv)
def iterate_data(self, split: Split):
df = self.df[self.df[self.SPLIT_COL] == split]
for idx, row in df.iterrows():
image_filename = os.path.join(IMG_PATH, row[self.FILENAME_COL])
if isinstance(row[self.COORDS_COL], str):
coords_filename = os.path.join(COORDS_PATH, row[self.COORDS_COL])
else:
coords_filename = None
yield image_filename, coords_filename
@staticmethod
def load_coordinates(label_filename: str) -> List[Tuple[int, int]]:
atom_coordinates = pd.read_csv(label_filename)
return list(zip(atom_coordinates['X'], atom_coordinates['Y']))
def split_length(self, split: Split):
df = self.df[self.df[self.SPLIT_COL] == split]
return len(df)
class HaadfCoordinates(CoordinatesDataset):
def __init__(self):
super().__init__(coord_image_csv=PT_DATASET)
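
# Illustrative sketch (not part of the original code): iterate the training
# split of the HaadfCoordinates index and load the ground-truth atom
# coordinates where a coordinate file exists.
def _example_iterate_coords():
    dataset = HaadfCoordinates()
    for image_filename, coords_filename in dataset.iterate_data(Split.TRAIN):
        coords = CoordinatesDataset.load_coordinates(coords_filename) if coords_filename else []
        yield image_filename, coords
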
class CropsDataset(ImageClassificationDataset):
@staticmethod
def get_filenames_labels(split: Split):
df = pd.read_csv(CROPS_DATASET)
split_df = df[df[CropsColumns.SPLIT] == split]
filenames = (CROPS_PATH + os.sep + split_df[CropsColumns.FILENAME]).to_list()
labels = (split_df[CropsColumns.LABEL]).to_list()
return filenames, labels
class CropsCustomDataset(ImageClassificationDataset):
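    """Like CropsDataset, but reads crops from a caller-supplied CSV and directory.

    Note: ``get_filenames_labels`` takes extra arguments here, so the inherited
    ``train_dataset``/``val_dataset``/``test_dataset`` helpers (which pass only
    a Split) cannot be used directly; call
    ``cls(*cls.get_filenames_labels(split, crops_dataset, crops_path))`` instead.
    """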
@staticmethod
def get_filenames_labels(split: Split, crops_dataset: str, crops_path: str):
df = pd.read_csv(crops_dataset)
split_df = df[df[CropsColumns.SPLIT] == split]
filenames = (crops_path + os.sep + split_df[CropsColumns.FILENAME]).to_list()
labels = (split_df[CropsColumns.LABEL]).to_list()
return filenames, labels
class ProbsDataset(ImageClassificationDataset):
@staticmethod
def get_filenames_labels(split: Split):
df = pd.read_csv(PROBS_DATASET)
split_df = df[df[ProbsColumns.SPLIT] == split]
filenames = (PROBS_PATH + os.sep + split_df[ProbsColumns.FILENAME]).to_list()
labels = (split_df[ProbsColumns.LABEL]).to_list()
return filenames, labels
class SlidingCropDataset(Dataset):
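    """Sliding-window dataset over a single image: 21x21 crops every 2 pixels.

    Each item is (crop_tensor, x_center, y_center); crops get the same
    background subtraction as above plus Normalize(mean=0.5, std=0.5).
    """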
def __init__(self, tif_filename, include_coords=True):
self.filename = tif_filename
self.include_coords = include_coords
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.5], std=[0.5])
])
self.n_labels = 2
self.step_size = 2
self.window_size = (21, 21)
self.loaded_crops = []
self.loaded_coords = []
self.load_crops()
def sliding_window(self, image):
# slide a window across the image
for x in range(0, image.shape[0] - self.window_size[0], self.step_size):
for y in range(0, image.shape[1] - self.window_size[1], self.step_size):
# yield the current window
center_x = x + ((self.window_size[0] - 1) // 2)
center_y = y + ((self.window_size[1] - 1) // 2)
yield center_x, center_y, image[x:x + self.window_size[0], y:y + self.window_size[1]]
@staticmethod
def load_image(img_filename):
img = Image.open(img_filename)
np_img = np.asarray(img).astype(np.float32)
np_bg = median_filter(np_img, size=(40, 40))
np_clean = np_img - np_bg
np_normed = (np_clean - np_clean.min()) / (np_clean.max() - np_clean.min())
return np_normed
def load_crops(self):
img = self.load_image(self.filename)
for x_center, y_center, img_crop in self.sliding_window(img):
self.loaded_crops.append(img_crop)
self.loaded_coords.append((x_center, y_center))
def get_n_labels(self):
return self.n_labels
def __len__(self):
return len(self.loaded_crops)
def __getitem__(self, idx):
crop = self.loaded_crops[idx]
x, y = self.loaded_coords[idx]
crop = self.transform(crop)
return crop, x, y
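
# Illustrative sketch (not part of the original code): batch the sliding-window
# crops with a DataLoader; a classifier would score each (1, 21, 21) crop and
# the returned centre coordinates locate detections in the source image.
# The helper name and batch size are arbitrary assumptions.
def _example_sliding_batches(tif_filename, batch_size=64):
    from torch.utils.data import DataLoader
    dataset = SlidingCropDataset(tif_filename)
    return DataLoader(dataset, batch_size=batch_size, shuffle=False)
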
def get_image_path_without_coords(split: Optional[Split] = None):
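    """Collect HAADF image paths and pos/neg labels grouped by filename prefix.

    TRAIN selects prefixes that do have a coordinate file (despite the function
    name), TEST those without one, and None selects every prefix. For each
    prefix the lexicographically last '_PtNC' match is labelled 1 and the last
    '_NC' match is labelled 0.
    """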
coords_prefix_set = set()
for coords_name in os.listdir(COORDS_PATH):
coord_prefix = coords_name.split('_')[0]
coords_prefix_set.add(coord_prefix)
all_prefixes_set = set()
for tif_name in os.listdir(IMG_PATH):
coord_prefix = tif_name.split('_')[0]
all_prefixes_set.add(coord_prefix)
if split == Split.TRAIN:
missing_prefixes = coords_prefix_set
elif split == Split.TEST:
missing_prefixes = all_prefixes_set - coords_prefix_set
elif split is None:
missing_prefixes = all_prefixes_set
else:
        raise ValueError(f"Unsupported split: {split}")
tif_filenames_list = []
labels_list = []
for prefix in missing_prefixes:
filename_matches = glob.glob(os.path.join(IMG_PATH, f'{prefix}_HAADF*NC*'))
if len(filename_matches) == 0:
continue
pos_filenames = [filename for filename in filename_matches if '_PtNC' in filename]
neg_filenames = [filename for filename in filename_matches if '_NC' in filename]
if len(pos_filenames) > 0:
pos_filename = sorted(pos_filenames)[-1]
tif_filenames_list.append(pos_filename)
labels_list.append(1)
if len(neg_filenames) > 0:
neg_filename = sorted(neg_filenames)[-1]
tif_filenames_list.append(neg_filename)
labels_list.append(0)
return tif_filenames_list, labels_list
if __name__ == "__main__":
    filenames_list, labels_list = get_image_path_without_coords()
    filename = filenames_list[0]
    dataset = SlidingCropDataset(filename)