# atoms_detection/dl_detection.py
# Author: Romain Graux
# NOTE(review): lines above were file-viewer scrape residue (commit hash,
# "raw history blame", size); converted to a comment header so the file parses.
import os
from typing import Iterator, List, Tuple, Union

import numpy as np
import torch
import torch.nn
import torch.nn.functional

from atoms_detection.detection import Detection
from atoms_detection.image_preprocessing import dl_prepro_image
from atoms_detection.training_model import model_pipeline
from utils.constants import ModelArgs
from utils.paths import PREDS_PATH
class DLDetection(Detection):
    """CNN-based atom detector.

    Scores every pixel of an image by sliding a fixed-size window over a
    zero-padded copy and classifying each crop with a 2-class model; the
    probability of class 1 ("atom" — presumed from usage, confirm against
    the training code) becomes the pixel's value in the prediction map.
    """

    def __init__(self,
                 model_name: ModelArgs,
                 ckpt_filename: str,
                 dataset_csv: str,
                 threshold: float,
                 detections_path: str,
                 inference_cache_path: str,
                 batch_size: int = 64,
                 ):
        """
        Args:
            model_name: key into ``model_pipeline`` selecting the architecture.
            ckpt_filename: checkpoint file holding a ``'state_dict'`` entry.
            dataset_csv: dataset description, forwarded to ``Detection``.
            threshold: detection threshold, forwarded to ``Detection``.
            detections_path: output path, forwarded to ``Detection``.
            inference_cache_path: cache path, forwarded to ``Detection``.
            batch_size: number of window crops scored per forward pass.
        """
        self.model_name = model_name
        self.ckpt_filename = ckpt_filename
        self.device = self.get_torch_device()
        self.batch_size = batch_size
        self.stride = 1          # window step in pixels
        self.padding = 10        # zero-padding added on every image border
        self.window_size = (21, 21)  # (width, height) of each classified crop
        super().__init__(dataset_csv, threshold, detections_path, inference_cache_path)

    @staticmethod
    def get_torch_device() -> torch.device:
        """Return the best available torch device (MPS > CUDA > CPU)."""
        if torch.backends.mps.is_available():
            return torch.device("mps")
        if torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")

    def sliding_window(self, image: np.ndarray, padding: int = 0) -> Iterator[Tuple[int, int, np.ndarray]]:
        """Yield ``(center_x, center_y, crop)`` for every window position.

        Centers are shifted by ``-padding`` so that, when a padded image is
        scanned, they index into the original (unpadded) image.
        """
        # For even window sizes the "center" is biased toward the top-left.
        x_to_center = self.window_size[0] // 2 - 1 if self.window_size[0] % 2 == 0 else self.window_size[0] // 2
        y_to_center = self.window_size[1] // 2 - 1 if self.window_size[1] % 2 == 0 else self.window_size[1] // 2
        for y in range(0, image.shape[0] - self.window_size[1] + 1, self.stride):
            for x in range(0, image.shape[1] - self.window_size[0] + 1, self.stride):
                center_x = x + x_to_center
                center_y = y + y_to_center
                yield center_x - padding, center_y - padding, image[y:y + self.window_size[1], x:x + self.window_size[0]]

    def batch_sliding_window(self, image: np.ndarray, padding: int = 0) -> Iterator[Tuple[List[int], List[int], List[np.ndarray]]]:
        """Group ``sliding_window`` output into batches of ``self.batch_size``.

        Yields ``(x_centers, y_centers, crops)`` lists; the final batch may be
        smaller and is emitted only if non-empty.
        """
        x_idx_list: List[int] = []
        y_idx_list: List[int] = []
        images_list: List[np.ndarray] = []
        for _x, _y, _img in self.sliding_window(image, padding=padding):
            x_idx_list.append(_x)
            y_idx_list.append(_y)
            images_list.append(_img)
            if len(images_list) == self.batch_size:
                yield x_idx_list, y_idx_list, images_list
                x_idx_list, y_idx_list, images_list = [], [], []
        if images_list:  # flush the partial last batch
            yield x_idx_list, y_idx_list, images_list

    def padding_image(self, img: np.ndarray) -> np.ndarray:
        """Return ``img`` zero-padded by ``self.padding`` pixels on each side.

        Note: the result is float64 regardless of the input dtype
        (``np.zeros`` default), matching the original behavior.
        """
        image_padded = np.zeros((img.shape[0] + self.padding * 2, img.shape[1] + self.padding * 2))
        image_padded[self.padding:-self.padding, self.padding:-self.padding] = img
        return image_padded

    def load_model(self) -> torch.nn.Module:
        """Build the selected architecture and load checkpoint weights, in eval mode."""
        checkpoint = torch.load(self.ckpt_filename, map_location=self.device)
        model = model_pipeline[self.model_name](num_classes=2).to(self.device)
        model.load_state_dict(checkpoint['state_dict'])
        model.eval()
        return model

    def images_to_torch_input(self, images_list: List[np.ndarray]) -> torch.Tensor:
        """Stack window crops into an (N, 1, H, W) float tensor on ``self.device``."""
        # expand_dims converts the list to an (N, H, W) array and inserts
        # the singleton channel axis in one step.
        expanded_img = np.expand_dims(images_list, axis=1)
        input_tensor = torch.from_numpy(expanded_img).float()
        return input_tensor.to(self.device)

    def get_prediction_map(self, padded_image: np.ndarray) -> np.ndarray:
        """Classify every window of ``padded_image`` into a per-pixel probability map.

        Returns an array shaped like the original (unpadded) image, holding
        the softmax probability of class 1 at each window center.
        """
        _shape = padded_image.shape
        pred_map = np.zeros((_shape[0] - self.padding * 2, _shape[1] - self.padding * 2))
        model = self.load_model()
        # Inference only: no_grad avoids building autograd graphs per batch,
        # which the original did and which wastes memory.
        with torch.no_grad():
            for x_idxs, y_idxs, image_crops in self.batch_sliding_window(padded_image, padding=self.padding):
                torch_input = self.images_to_torch_input(image_crops)
                output = model(torch_input)
                pred_prob = torch.nn.functional.softmax(output, 1)
                # Column 1 is the positive-class probability.
                pred_map[np.array(y_idxs), np.array(x_idxs)] = pred_prob.cpu().numpy()[:, 1]
        return pred_map

    def image_to_pred_map(self, img: np.ndarray, return_intermediate: bool = False) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray, np.ndarray]]:
        """Preprocess ``img``, pad it, and compute its per-pixel prediction map.

        Returns the prediction map, or the triple
        ``(preprocessed_img, padded_image, pred_map)`` when
        ``return_intermediate`` is True.
        """
        preprocessed_img = dl_prepro_image(img)
        print(f"preprocessed_img.shape: {preprocessed_img.shape}, μ: {np.mean(preprocessed_img)}, σ: {np.std(preprocessed_img)}")
        padded_image = self.padding_image(preprocessed_img)
        print(f"padded_image.shape: {padded_image.shape}, μ: {np.mean(padded_image)}, σ: {np.std(padded_image)}")
        pred_map = self.get_prediction_map(padded_image)
        print(f"pred_map.shape: {pred_map.shape}, μ: {np.mean(pred_map)}, σ: {np.std(pred_map)}")
        if return_intermediate:
            return preprocessed_img, padded_image, pred_map
        return pred_map