"""Direction/period prediction models and a crop-based angle predictor.

Contents:
    * ``ConfidenceScaler`` -- maps a score to its empirical quantile.
    * ``PatchedPredictor`` -- aggregates a model's per-crop direction
      predictions into a single angle plus a confidence.
    * ``CosineLoss``, ``AngleParser2d``, ``AngleRegularizer``,
      ``AngleRegularizerLog`` -- loss / output-head building blocks.
    * ``StripsModel`` and ``StripsModelLumenWidth`` -- Lightning modules that
      predict (direction, period, lumen fraction) and
      (direction, period, lumen width) respectively.
"""
import functools
import importlib
from collections.abc import Mapping

import numpy as np
import torch
import pytorch_lightning as pl
import timm
from scipy.stats import circmean, circstd
from scipy import ndimage
from skimage.transform import resize

from sampling import get_crop_batch
from granum_utils import get_circle_mask
import image_transforms
from envelope_correction import calculate_best_angle_from_mask


## config helpers

def _locate(path):
    """Resolve a dotted path such as ``'torch.optim.lr_scheduler.MultiStepLR'``.

    The longest importable module prefix of ``path`` is imported and the
    remaining components are looked up as attributes on it.

    Raises:
        ImportError: if no prefix is an importable module or an attribute is
            missing. Import errors raised *inside* an existing module (e.g. a
            missing dependency of that module) are propagated unchanged.
    """
    parts = path.split('.')
    for n_module_parts in range(len(parts), 0, -1):
        module_name = '.'.join(parts[:n_module_parts])
        try:
            obj = importlib.import_module(module_name)
        except ModuleNotFoundError as err:
            # Only skip "this prefix is not a module"; anything else is a real
            # failure while importing a module that does exist.
            missing = err.name or ''
            if missing and (module_name == missing or module_name.startswith(missing + '.')):
                continue
            raise
        try:
            for attr in parts[n_module_parts:]:
                obj = getattr(obj, attr)
        except AttributeError as err:
            raise ImportError(f'cannot locate {path!r}') from err
        return obj
    raise ImportError(f'cannot locate {path!r}')


def _is_target_config(value):
    """Return True if ``value`` is a mapping describing an object to build."""
    return isinstance(value, Mapping) and '_target_' in value


def _instantiate(config, *args, **kwargs):
    """Build an object from a Hydra-style config (dependency-free subset).

    Supported special keys:
        ``_target_``    dotted import path (or a callable) to call; required.
        ``_args_``      positional arguments, placed before ``*args``.
        ``_partial_``   if true, return a ``functools.partial`` instead of
                        calling the target.
        ``_recursive_`` if false, nested configs are passed through as-is; by
                        default nested mappings that carry a ``_target_`` are
                        instantiated first.
        ``_convert_``   accepted and ignored.
    All other keys become keyword arguments; ``kwargs`` given here override
    them.

    Raises:
        ValueError: if ``config`` has no ``_target_`` entry.
    """
    params = dict(config)
    if '_target_' not in params:
        raise ValueError(f'config has no "_target_" entry: {params!r}')
    target = params.pop('_target_')
    positional = list(params.pop('_args_', ()))
    make_partial = bool(params.pop('_partial_', False))
    recursive = bool(params.pop('_recursive_', True))
    params.pop('_convert_', None)
    if recursive:
        positional = [_instantiate(v) if _is_target_config(v) else v for v in positional]
        params = {k: (_instantiate(v) if _is_target_config(v) else v) for k, v in params.items()}
    params.update(kwargs)
    factory = _locate(target) if isinstance(target, str) else target
    if make_partial:
        return functools.partial(factory, *positional, *args, **params)
    return factory(*positional, *args, **params)


def _build_lr_scheduler(optimizer, lr_hparams):
    """Create the learning-rate scheduler for ``optimizer`` from ``lr_hparams``.

    Two formats are accepted:
        * ``dict(classname='MultiStepLR', kwargs={...})`` -- a class from
          ``torch.optim.lr_scheduler`` (the models' default format);
        * a Hydra-style mapping with a ``_target_`` entry; the optimizer is
          passed to the target as its first positional argument.
    """
    if '_target_' in lr_hparams:
        return _instantiate({**lr_hparams, '_partial_': True})(optimizer)
    scheduler_cls = getattr(torch.optim.lr_scheduler, lr_hparams['classname'])
    return scheduler_cls(optimizer, **lr_hparams.get('kwargs', {}))


def _to_numpy(x):
    """Return ``x`` (tensor or array-like) as a NumPy array on the CPU."""
    return torch.as_tensor(x).detach().cpu().numpy()


def _mean_abs_diff(a, b):
    """Mean absolute difference between two tensors / arrays."""
    return np.mean(np.abs(_to_numpy(a) - _to_numpy(b)))


def _angle_mae_deg(target_deg, pred_deg):
    """Mean absolute error in degrees between 180-degree-periodic angles.

    The difference is doubled before ``cos``/``arccos`` and halved afterwards,
    so angles that differ by 180 degrees count as identical.
    """
    diff = np.radians(_to_numpy(pred_deg) - _to_numpy(target_deg))
    return np.mean(0.5 * np.degrees(np.arccos(np.cos(2 * diff))))


## inference

class ConfidenceScaler:
    """Map a score to its empirical quantile within reference data.

    Args:
        data: reference scores. A flattened, sorted copy is stored, so the
            caller's array is left untouched.

    Calling the scaler with ``x`` returns the fraction of reference scores
    strictly smaller than ``x`` (``np.searchsorted`` with the default
    ``side='left'``), i.e. a value in ``[0, 1]``.

    Raises:
        ValueError: if ``data`` is empty.
    """

    def __init__(self, data: np.ndarray):
        sorted_data = np.sort(np.asarray(data), axis=None)
        if sorted_data.size == 0:
            raise ValueError('ConfidenceScaler needs at least one reference value')
        self.data = sorted_data

    def __call__(self, x):
        return np.searchsorted(self.data, x) / len(self.data)


class PatchedPredictor:
    """Estimate one direction angle for an image from many random crops.

    Crops at several scales/rotations are drawn with ``get_crop_batch``, passed
    through ``model`` as one batch, de-rotated by their augmentation angle and
    combined with a circular mean over a 180-degree period. The circular
    standard deviation of the per-crop estimates is mapped to a confidence.

    Args:
        model: callable returning a 3-tuple whose first element holds the
            per-crop direction predictions in degrees (the other two outputs
            are ignored).
        crop_size: crop side length passed to ``get_crop_batch``.
        normalization: ``dict(mean=..., std=...)`` applied after scaling the
            crops by 1/255.
        n_samples: passed to ``get_crop_batch`` as ``samples_per_scale``.
        mask: ``None`` or ``'circle'``; with ``'circle'`` the pixels selected by
            ``get_circle_mask`` are set to 0 in every crop.
        filter_outliers: stored but currently not used by ``__call__``.
        apply_radon: apply ``image_transforms.batched_radon`` to the crops.
        radon_size: ``(h, w)`` the Radon-transformed crops are resized to.
        angle_confidence_threshold: estimates with a lower confidence are
            reported as ``nan`` with confidence 0.
        use_envelope_correction: subtract the angle returned by
            ``calculate_best_angle_from_mask`` on the mask rotated by the
            estimate.

    Raises:
        ValueError: for an unknown ``mask`` value.
    """

    def __init__(self, model, crop_size=96,
                 normalization=dict(mean=0, std=1),
                 n_samples=32,
                 mask=None,  # 'circle', None
                 filter_outliers=True,
                 apply_radon=False,  # apply Radon transform
                 radon_size=(128, 128),  # (int, int) reshape radon transformed image to this shape
                 angle_confidence_threshold=0,
                 use_envelope_correction=True
                 ):
        self.model = model
        self.crop_size = crop_size
        self.normalization = normalization
        self.n_samples = n_samples
        if mask not in [None, 'circle']:
            raise ValueError(f'unknown mask {mask}')
        self.mask = mask
        self.filter_outliers = filter_outliers
        self.apply_radon = apply_radon
        self.radon_size = radon_size
        self.angle_confidence_threshold = angle_confidence_threshold
        self.use_envelope_correction = use_envelope_correction

    @torch.no_grad()
    def __call__(self, img: np.ndarray, mask: np.ndarray):
        """Return ``dict(est_angle=..., est_angle_confidence=...)`` for one image.

        ``est_angle`` is in degrees. It is ``nan`` (with confidence 0) when no
        crop could be sampled or the confidence is below
        ``angle_confidence_threshold``.
        """
        # Crop sampling is random; a fixed seed keeps predictions reproducible.
        # NOTE: this reseeds the global RNGs on every call.
        pl.seed_everything(44)

        # get crops with different scales and rotations
        crops, angles_tta, _ = get_crop_batch(
            img, mask,
            crop_size=self.crop_size,
            samples_per_scale=self.n_samples,
            use_variance_threshold=True
        )
        if len(crops) == 0:
            return dict(est_angle=np.nan, est_angle_confidence=0.)

        batch = self._preprocess_batch(crops)
        # Only the direction output is used; period and lumen are ignored.
        # NOTE(review): the model is called as-is -- make sure it is in eval
        # mode so normalization layers do not use/update batch statistics.
        preds_direction, _, _ = self.model(batch)

        # Undo each crop's augmentation rotation; directions are 180-periodic.
        est_angles = (_to_numpy(preds_direction) - np.asarray(angles_tta)) % 180
        est_angle = circmean(est_angles, low=-90, high=90) + 90
        est_angle_std = circstd(est_angles, low=-90, high=90)
        # confidence is 0.5 for a circular std of 10 degrees
        est_angle_confidence = self._std_to_confidence(est_angle_std, 10)

        if est_angle_confidence < self.angle_confidence_threshold:
            return dict(est_angle=np.nan, est_angle_confidence=0.)

        if self.use_envelope_correction and not np.isnan(est_angle):
            rotated_mask = ndimage.rotate(mask, -est_angle, reshape=True, order=0)
            est_angle -= calculate_best_angle_from_mask(rotated_mask)

        return dict(
            est_angle=est_angle,
            est_angle_confidence=est_angle_confidence,
        )

    def _apply_radon(self, batch):
        """Radon-transform every crop and resize the results to ``radon_size``.

        The returned tensor keeps ``batch``'s dtype when it is floating point
        and otherwise uses torch's default float dtype -- the same dtype that
        ``batch / 255`` yields on the non-Radon path -- so it can be fed to a
        float model (``resize`` itself always returns float64).
        """
        # may require a circle mask
        crops_radon = image_transforms.batched_radon(batch.cpu().numpy())
        # Assumes batched_radon returns (n_crops, H, W). Move the crop axis last
        # so ``resize`` only rescales the two spatial axes, then move it back.
        crops_radon = np.transpose(
            resize(np.transpose(crops_radon, (1, 2, 0)), self.radon_size),
            (2, 0, 1),
        )
        dtype = batch.dtype if batch.is_floating_point() else torch.get_default_dtype()
        return torch.tensor(crops_radon, dtype=dtype)

    def _preprocess_batch(self, batch):
        """Mask, optionally Radon-transform and normalize crops; add a channel axis.

        With ``mask='circle'`` the crops are modified in place.
        """
        if self.mask == 'circle':
            circle_mask = get_circle_mask(batch.shape[1])
            batch[:, circle_mask] = 0
        if self.apply_radon:
            batch = self._apply_radon(batch)
        batch = ((batch / 255) - self.normalization['mean']) / self.normalization['std']
        return batch.unsqueeze(1)  # add channel dimension

    def _filter_outliers(self, x, qmin=0.25, qmax=0.75):
        """Keep the values of ``x`` lying between its ``qmin`` and ``qmax`` quantiles.

        Not called by this class; the ``filter_outliers`` flag does not
        currently enable it.
        """
        x_min, x_max = np.quantile(x, [qmin, qmax])
        return x[(x >= x_min) & (x <= x_max)]

    def _std_to_confidence(self, x, x_thr, y_thr=0.5):
        """Map ``x`` in ``[0, inf)`` to ``(0, 1]``, decreasing, with f(0)=1 and f(x_thr)=y_thr."""
        return 1 / (1 + x * (1 - y_thr) / (x_thr * y_thr))


## losses and output heads

class CosineLoss(torch.nn.Module):
    """Mean of ``(1 - cos(x - y)) ** p``, multiplied by ``scale``.

    Args:
        p: exponent applied to each element-wise term.
        degrees: if True, ``x`` and ``y`` are given in degrees.
        scale: constant factor applied to the mean.
    """

    def __init__(self, p=1, degrees=False, scale=1):
        super().__init__()
        self.p = p
        self.degrees = degrees
        self.scale = scale

    def forward(self, x, y):
        if self.degrees:
            x = torch.deg2rad(x)
            y = torch.deg2rad(y)
        return torch.mean((1 - torch.cos(x - y)) ** self.p) * self.scale


class AngleParser2d(torch.nn.Module):
    """Turn the first two raw outputs into an angle in degrees.

    Columns 0 and 1 are passed through a sigmoid and centred (``- 0.5``) to
    form the y and x components of a vector. Its ``arctan2`` angle (within
    +-180 degrees) is multiplied by ``angle_range / 360``, so the result lies
    within +-``angle_range / 2`` degrees.
    """

    def __init__(self, angle_range=180):
        super().__init__()
        self.angle_range = angle_range

    def forward(self, batch):
        preds_y_proj = torch.sigmoid(batch[:, 0]) - 0.5
        preds_x_proj = torch.sigmoid(batch[:, 1]) - 0.5
        return self.angle_range / 360. * torch.rad2deg(torch.arctan2(preds_y_proj, preds_x_proj))


class AngleRegularizer(torch.nn.Module):
    """Penalty ``strength * ||r - scale||_p``.

    ``r`` is the Euclidean norm of ``batch`` taken over dim 1 (one value per
    row); the p-norm is then taken over all rows.
    """

    def __init__(self, strength=1.0, scale=1.0, p=2):
        super().__init__()
        self.strength = strength
        self.scale = scale
        self.p = p

    def forward(self, batch):
        r = torch.linalg.norm(batch, dim=1)
        return self.strength * torch.norm(r - self.scale, p=self.p)


class AngleRegularizerLog(torch.nn.Module):
    """Penalty ``strength * ||log(r / scale)||_p``.

    ``r`` is the Euclidean norm of ``batch`` over dim 1. A row with ``r == 0``
    gives ``log(0) = -inf``.
    """

    def __init__(self, strength=1.0, scale=1.0, p=2):
        super().__init__()
        self.strength = strength
        self.scale = scale
        self.p = p

    def forward(self, batch):
        r = torch.linalg.norm(batch, dim=1)
        return self.strength * torch.norm(torch.log(r / self.scale), p=self.p)


## models

class _StripsModelBase(pl.LightningModule):
    """Code shared by ``StripsModel`` and ``StripsModelLumenWidth``.

    Subclasses must call ``self.save_hyperparameters()`` in their own
    ``__init__`` so that ``self.hparams`` holds ``lr``, ``optimizer_hparams``,
    ``lr_hparams``, ``angle_hparams`` and ``regularizer_hparams``. They also
    set ``self.losses_weights`` and implement ``forward``,
    ``process_batch_supervised`` (returning ``(preds, losses, mae)`` with a
    ``losses['final']`` entry) and ``log_all``.
    """

    def _build_backbone(self, model_name):
        """Create the timm backbone, the angle head and the optional regularizer.

        The backbone takes 1-channel images and has 4 outputs: columns 0-1
        feed ``AngleParser2d``, column -2 is read as the period and column -1
        as the lumen output.
        """
        self.model = timm.create_model(model_name, in_chans=1, num_classes=4)
        self.angle_parser = AngleParser2d(**self.hparams.angle_hparams)
        self.regularizer = self._get_regularizer(self.hparams.regularizer_hparams)

    def _get_regularizer(self, regularizer_params):
        """Instantiate the regularizer from a ``_target_`` config; ``None`` disables it."""
        if regularizer_params is None:
            return None
        return _instantiate(regularizer_params)

    def _weighted_total(self, losses, keys):
        """Weighted sum of ``losses[k]`` for ``k`` in ``keys`` plus the regularization term."""
        total = sum(losses[key] * self.losses_weights[key] for key in keys)
        return total + losses.get('regularization', 0.) * self.losses_weights.get('regularization', 0.)

    def configure_optimizers(self):
        # AdamW is Adam with a correct implementation of weight decay (see
        # https://arxiv.org/pdf/1711.05101.pdf for details).
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr, **self.hparams.optimizer_hparams)
        scheduler = _build_lr_scheduler(optimizer, self.hparams.lr_hparams)
        return [optimizer], [scheduler]

    def training_step(self, batch, batch_idx):
        # "batch" is the output of the training data loader.
        _, losses, mae = self.process_batch_supervised(batch)
        self.log_all(losses, mae, prefix='train_')
        return losses['final']

    def validation_step(self, batch, batch_idx):
        _, losses, mae = self.process_batch_supervised(batch)
        self.log_all(losses, mae, prefix='val_')

    def test_step(self, batch, batch_idx):
        _, losses, mae = self.process_batch_supervised(batch)
        self.log_all(losses, mae, prefix='test_')


class StripsModel(_StripsModelBase):
    """Predicts direction (degrees), period and lumen fraction from an image.

    Args:
        model_name: timm architecture name.
        lr: AdamW learning rate.
        optimizer_hparams: extra keyword arguments for ``torch.optim.AdamW``.
        lr_hparams: scheduler config -- ``dict(classname=..., kwargs=...)``
            naming a ``torch.optim.lr_scheduler`` class, or a ``_target_``
            mapping.
        loss_hparams: loss weights ``rotation_weight`` and
            ``lumen_fraction_weight``, plus optional ``regularization_weight``
            (default 0).
        angle_hparams: keyword arguments for ``AngleParser2d``.
        regularizer_hparams: ``_target_`` mapping for a regularizer applied to
            the two raw angle outputs, or ``None``.
        sigmoid_smoother: factor applied to the lumen logit before the sigmoid.
    """

    def __init__(self,
                 model_name='resnet18',
                 lr=0.001,
                 optimizer_hparams=dict(),
                 lr_hparams=dict(classname='MultiStepLR', kwargs=dict(milestones=[100, 150], gamma=0.1)),
                 loss_hparams=dict(rotation_weight=10., lumen_fraction_weight=50.),
                 angle_hparams=dict(angle_range=180.),
                 regularizer_hparams=None,
                 sigmoid_smoother=10.
                 ):
        super().__init__()
        # Exports the hyperparameters to a YAML file and creates the
        # "self.hparams" namespace (the default dicts above are never mutated).
        self.save_hyperparameters()
        self._build_backbone(model_name)
        self.losses = {
            'direction': CosineLoss(p=2., degrees=True),
            'period': torch.nn.functional.mse_loss,
            'lumen_fraction': torch.nn.functional.mse_loss,
        }
        self.losses_weights = {
            'direction': self.hparams.loss_hparams['rotation_weight'],
            'period': 1,
            'lumen_fraction': self.hparams.loss_hparams['lumen_fraction_weight'],
            'regularization': self.hparams.loss_hparams.get('regularization_weight', 0.),
        }

    def forward(self, x, return_raw=False):
        """Return ``(direction, period, lumen_fraction)`` for an image batch.

        The lumen fraction is ``sigmoid(raw[:, -1] * sigmoid_smoother)``, i.e.
        it lies in (0, 1). With ``return_raw=True`` the raw backbone output is
        appended as a fourth element.
        """
        raw = self.model(x)
        outputs = [
            self.angle_parser(raw),
            raw[:, -2],
            torch.sigmoid(raw[:, -1] * self.hparams.sigmoid_smoother),
        ]
        if return_raw:
            outputs.append(raw)
        return tuple(outputs)

    def process_batch_supervised(self, batch):
        """Return ``(preds, losses, mae)`` for a batch dict.

        ``batch`` must contain ``'image'``, ``'direction'`` (degrees),
        ``'period'`` and ``'lumen_fraction'``. Directions are doubled before
        the cosine loss, which makes that loss 180-degree periodic.
        """
        preds = {}
        preds['direction'], preds['period'], preds['lumen_fraction'], preds_raw = \
            self.forward(batch['image'], return_raw=True)

        losses = {
            'direction': self.losses['direction'](2 * batch['direction'], 2 * preds['direction']),
            'period': self.losses['period'](batch['period'], preds['period']),
            'lumen_fraction': self.losses['lumen_fraction'](batch['lumen_fraction'], preds['lumen_fraction']),
        }
        if self.regularizer is not None:
            losses['regularization'] = self.regularizer(preds_raw[:, :2])
        losses['final'] = self._weighted_total(losses, ('direction', 'period', 'lumen_fraction'))

        mae = {
            'period': _mean_abs_diff(batch['period'], preds['period']),
            'direction': _angle_mae_deg(batch['direction'], preds['direction']),
            'lumen_fraction': _mean_abs_diff(preds['lumen_fraction'], batch['lumen_fraction']),
        }
        return preds, losses, mae

    def log_all(self, losses, mae, prefix=''):
        self.log(f"{prefix}angle_loss", losses['direction'].item())
        self.log(f"{prefix}period_loss", losses['period'].item())
        self.log(f"{prefix}lumen_fraction_loss", losses['lumen_fraction'].item())
        self.log(f"{prefix}period_difference", mae['period'])
        self.log(f"{prefix}angle_difference", mae['direction'])
        self.log(f"{prefix}lumen_fraction_difference", mae['lumen_fraction'])
        self.log(f"{prefix}loss", losses['final'])
        if 'regularization' in losses:
            self.log(f"{prefix}regularization_loss", losses['regularization'].item())


class StripsModelLumenWidth(_StripsModelBase):
    """Predicts direction (degrees), period and lumen width from an image.

    Arguments are as for ``StripsModel`` except that ``loss_hparams`` uses
    ``lumen_width_weight``. ``sigmoid_smoother`` is accepted (and saved in the
    hyperparameters) but not used: the lumen width is the unbounded raw output.
    """

    def __init__(self,
                 model_name='resnet18',
                 lr=0.001,
                 optimizer_hparams=dict(),
                 lr_hparams=dict(classname='MultiStepLR', kwargs=dict(milestones=[100, 150], gamma=0.1)),
                 loss_hparams=dict(rotation_weight=10., lumen_width_weight=50.),
                 angle_hparams=dict(angle_range=180.),
                 regularizer_hparams=None,
                 sigmoid_smoother=10.
                 ):
        super().__init__()
        # Exports the hyperparameters to a YAML file and creates the
        # "self.hparams" namespace (the default dicts above are never mutated).
        self.save_hyperparameters()
        self._build_backbone(model_name)
        self.losses = {
            'direction': CosineLoss(p=2., degrees=True),
            'period': torch.nn.functional.mse_loss,
            'lumen_width': torch.nn.functional.mse_loss,
        }
        self.losses_weights = {
            'direction': self.hparams.loss_hparams['rotation_weight'],
            'period': 1,
            'lumen_width': self.hparams.loss_hparams['lumen_width_weight'],
            'regularization': self.hparams.loss_hparams.get('regularization_weight', 0.),
        }

    def forward(self, x, return_raw=False):
        """Return ``(direction, period, lumen_width)`` for an image batch.

        The lumen width is the raw last output column (no activation). With
        ``return_raw=True`` the raw backbone output is appended.
        """
        raw = self.model(x)
        outputs = [self.angle_parser(raw), raw[:, -2], raw[:, -1]]
        if return_raw:
            outputs.append(raw)
        return tuple(outputs)

    def process_batch_supervised(self, batch):
        """Return ``(preds, losses, mae)`` for a batch dict.

        ``batch`` must contain ``'image'``, ``'direction'`` (degrees),
        ``'period'`` and ``'lumen_width'``. Besides the per-output errors,
        ``mae['lumen_fraction']`` compares ``lumen_width / period`` of the
        predictions and of the targets.
        """
        preds = {}
        preds['direction'], preds['period'], preds['lumen_width'], preds_raw = \
            self.forward(batch['image'], return_raw=True)

        losses = {
            'direction': self.losses['direction'](2 * batch['direction'], 2 * preds['direction']),
            'period': self.losses['period'](batch['period'], preds['period']),
            'lumen_width': self.losses['lumen_width'](batch['lumen_width'], preds['lumen_width']),
        }
        if self.regularizer is not None:
            losses['regularization'] = self.regularizer(preds_raw[:, :2])
        losses['final'] = self._weighted_total(losses, ('direction', 'period', 'lumen_width'))

        lumen_fraction_pred = _to_numpy(preds['lumen_width']) / _to_numpy(preds['period'])
        lumen_fraction_gt = _to_numpy(batch['lumen_width']) / _to_numpy(batch['period'])
        mae = {
            'period': _mean_abs_diff(batch['period'], preds['period']),
            'direction': _angle_mae_deg(batch['direction'], preds['direction']),
            'lumen_width': _mean_abs_diff(preds['lumen_width'], batch['lumen_width']),
            'lumen_fraction': np.mean(np.abs(lumen_fraction_pred - lumen_fraction_gt)),
        }
        return preds, losses, mae

    def log_all(self, losses, mae, prefix=''):
        for name, value in losses.items():
            self.log(f'{prefix}{name}_loss', value.item() if isinstance(value, torch.Tensor) else value)
        for name, value in mae.items():
            self.log(f'{prefix}{name}_difference', value.item() if isinstance(value, torch.Tensor) else value)