Spaces:
Runtime error
Runtime error
import torch | |
import ipywidgets | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from IPython.display import display | |
from itertools import chain, islice | |
from ipywidgets import interactive, widgets | |
def _create_label(text: str) -> ipywidgets.widgets.Label:
    "Return a full-width Label widget showing `text`, horizontally centred via flexbox."
    centred_layout = widgets.Layout(width='100%', display='flex', justify_content="center")
    return widgets.Label(text, layout=centred_layout)
def _create_slider(
        slider_min: int,
        slider_max: int,
        value: int,
        step: int = 1,
        description: str = '',
        continuous_update: bool = True,
        readout: bool = False,
        slider_type: str = 'IntSlider',
        **kwargs) -> ipywidgets.widgets:
    """Build a slider widget of class `widgets.<slider_type>`.

    The slider spans 99% of its container (at least 200px) and shows its full
    description. Any extra keyword arguments are forwarded to the widget
    constructor.
    """
    slider_cls = getattr(widgets, slider_type)
    options = dict(
        min=slider_min,
        max=slider_max,
        step=step,
        value=value,
        description=description,
        continuous_update=continuous_update,
        readout=readout,
        layout=widgets.Layout(width='99%', min_width='200px'),
        style={'description_width': 'initial'},
    )
    return slider_cls(**options, **kwargs)
def _create_button(description: str) -> ipywidgets.widgets.Button:
    "Return a Button labelled `description`, 95% wide with a 5px margin."
    button_layout = widgets.Layout(width='95%', margin='5px 5px')
    return widgets.Button(description=description, layout=button_layout)
def _create_togglebutton(description: str,
                         value: bool,
                         **kwargs) -> ipywidgets.widgets.ToggleButton:
    """Create a toggle button widget.

    Args:
        description: text shown on the button.
        value: initial toggle state (True = pressed).
        **kwargs: forwarded unchanged to `widgets.ToggleButton`.

    Returns:
        A `widgets.ToggleButton`, 95% wide with a 5px margin.
    """
    return widgets.ToggleButton(
        description=description,
        value=value,
        layout=widgets.Layout(width='95%', margin='5px 5px'),
        **kwargs)
class BasicViewer():
    """ Base class for viewing TensorDicom3D objects.

    Args:
        x: main image object to view as rank 3 tensor; slices are taken along
            the first dimension.
        y: either a segmentation mask as rank 3 tensor with the same shape as
            `x`, or a label as str.
        prediction: a class prediction as str; when `y` is a str label the
            label is shown green if it equals `prediction`, red otherwise.
        description: description of the whole image, shown above the plot.
        figsize: (width, height) of the image, passed to `plt.subplots`.
        cmap: colormap for the image.

    Returns:
        Instance of BasicViewer
    """

    def __init__(self, x: torch.Tensor, y=None, prediction: str = None, description: str = None,
                 figsize=(3, 3), cmap: str = 'bone'):
        assert x.ndim == 3, f"x.ndim needs to be equal to 3 but is {x.ndim}"
        if isinstance(y, torch.Tensor):
            assert x.shape == y.shape, f"Shapes of x {x.shape} and y {y.shape} do not match"
        self.x = x
        self.y = y
        self.prediction = prediction
        self.description = description
        self.figsize = figsize
        self.cmap = cmap
        self.with_mask = isinstance(y, torch.Tensor)
        # Slices are numbered from 1 in the UI; `_plot_slice` subtracts 1 before indexing.
        self.slice_range = (1, len(x))  # len(x) == x.shape[0]

    def _plot_slice(self, im_slice, with_mask, px_range):
        "Plot 1-based slice `im_slice` of `x` clipped to `px_range`, optionally overlaying the mask."
        _, ax = plt.subplots(1, 1, figsize=self.figsize)
        ax.imshow(self.x[im_slice - 1, :, :].clip(*px_range), cmap=self.cmap)
        if with_mask and isinstance(self.y, torch.Tensor):
            ax.imshow(self.y[im_slice - 1, :, :], cmap='jet', alpha=0.25)
        plt.axis('off')
        ax.set_xticks([])
        ax.set_yticks([])
        plt.show()

    def _create_image_box(self, figsize):
        """Create the widgets for one image and stack them in a VBox.

        Args:
            figsize: kept for backward compatibility; the plot and its output
                area are sized from `self.figsize`.

        Returns:
            `widgets.VBox` with optional description, label line, pixel-range
            slider, image output and slice slider (plus a mask toggle when `y`
            is a mask tensor).
        """
        items = []
        if self.description:
            items.append(_create_label(self.description))

        if isinstance(self.y, str):
            label = f'{self.y} | {self.prediction}' if self.prediction else self.y
            if self.prediction:
                font_color = 'green' if self.y == self.prediction else 'red'
                # LaTeX colour markup rendered by the Label widget.
                label = r'\(\color{' + font_color + r'} {' + label + r'}\)'
            y_label = _create_label(label)
        else:
            y_label = _create_label(' ')

        lo, hi = min(self.slice_range), max(self.slice_range)
        slice_slider = _create_slider(
            slider_min=lo,
            slider_max=hi,
            value=max(lo, hi // 2),  # keep the start value inside [lo, hi] for 1-slice volumes
            readout=True)
        toggle_mask_button = _create_togglebutton('Show Mask', True)

        # `.item()` hands the widget plain Python numbers instead of 0-d numpy arrays.
        is_float = torch.is_floating_point(self.x)
        px_min, px_max = self.x.min().item(), self.x.max().item()
        range_slider = _create_slider(
            slider_min=px_min,
            slider_max=px_max,
            value=[px_min, px_max],
            slider_type='FloatRangeSlider' if is_float else 'IntRangeSlider',
            step=0.01 if is_float else 1,
            readout=True)

        image_output = widgets.interactive_output(
            f=self._plot_slice,
            controls={'im_slice': slice_slider,
                      'with_mask': toggle_mask_button,
                      'px_range': range_slider})
        # Fix the output area to the figure size to suppress flickering on redraw;
        # matplotlib's figsize is (width, height).
        width, height = self.figsize
        image_output.layout.width = f'{width / 1.2}in'
        image_output.layout.height = f'{height / 1.2}in'

        items += [y_label, range_slider, image_output]
        if isinstance(self.y, torch.Tensor):
            items.append(widgets.HBox([slice_slider, toggle_mask_button]))
        else:
            items.append(slice_slider)

        return widgets.VBox(
            items,
            layout=widgets.Layout(
                border='none',
                margin='10px 5px 0px 0px',
                padding='5px'))

    def _generate_views(self):
        "Build the widget tree and store it in `self.box`."
        image_box = self._create_image_box(self.figsize)
        self.box = widgets.HBox(children=[image_box])

    def image_box(self):
        "Return a freshly built image VBox (a method: call it to get the widget)."
        return self._create_image_box(self.figsize)

    def show(self):
        "Build the views and display them in the notebook."
        self._generate_views()
        plt.style.use('default')
        display(self.box)
class DicomExplorer(BasicViewer):
    """ DICOM viewer for basic image analysis inside iPython notebooks.

    Can display a single 3D volume together with a segmentation mask, a histogram
    of voxel/pixel values and some summary statistics.
    Allows simple windowing by clipping the pixel/voxel values to a region, which
    can be manually specified.
    """

    # Layout shared by the image, histogram and statistics columns.
    vbox_layout = widgets.Layout(
        margin='10px 5px 5px 5px',
        padding='5px',
        display='flex',
        flex_flow='column',
        align_items='center',
        min_width='250px')

    def _plot_hist(self, px_range):
        """Plot a 100-bin histogram of all values in `x`, highlighting bins inside `px_range`."""
        values = self.x.numpy().ravel()
        lo, hi = px_range
        _, ax = plt.subplots(figsize=self.figsize)
        _, bins, patches = ax.hist(values, 100, color='grey')
        # Colour every bin that overlaps [lo, hi]. Using the returned bin edges
        # (rather than assuming the data starts at 0) keeps this correct for
        # negative values and avoids a Python-level max() over the whole volume.
        for patch, left, right in zip(patches, bins[:-1], bins[1:]):
            if left <= hi and right >= lo:
                patch.set_facecolor('darkblue')
        plt.show()

    def _image_summary(self, px_range):
        """Print summary statistics of `x` after clipping it to `px_range`."""
        x = self.x.clip(*px_range)
        # torch.mean/pow need a floating dtype; integer volumes would raise otherwise.
        xf = x if torch.is_floating_point(x) else x.float()
        mean = xf.mean()
        diffs = xf - mean
        var = torch.mean(torch.pow(diffs, 2.0))
        std = torch.pow(var, 0.5)
        zscores = diffs / std
        skews = torch.mean(torch.pow(zscores, 3.0))
        kurt = torch.mean(torch.pow(zscores, 4.0)) - 3.0
        table = 'Statistics:\n' + \
            f' Mean px value: {mean} \n' + \
            f' Std of px values: {xf.std()} \n' + \
            f' Min px value: {x.min()} \n' + \
            f' Max px value: {x.max()} \n' + \
            f' Median px value: {x.median()} \n' + \
            f' Skewness: {skews} \n' + \
            f' Kurtosis: {kurt} \n\n' + \
            'Tensor properties \n' + \
            f' Tensor shape: {tuple(x.shape)}\n' + \
            f' Tensor dtype: {x.dtype}'
        print(table)

    def _generate_views(self):
        """Build image, histogram and statistics columns and store them in `self.box`."""
        lo, hi = min(self.slice_range), max(self.slice_range)
        slice_slider = _create_slider(
            slider_min=lo,
            slider_max=hi,
            value=max(lo, hi // 2),  # keep the start value inside [lo, hi]
            readout=True)
        toggle_mask_button = _create_togglebutton('Show Mask', True)

        # `.item()` hands the widget plain Python numbers instead of 0-d numpy arrays.
        is_float = torch.is_floating_point(self.x)
        px_min, px_max = self.x.min().item(), self.x.max().item()
        range_slider = _create_slider(
            slider_min=px_min,
            slider_max=px_max,
            value=[px_min, px_max],
            continuous_update=False,
            slider_type='FloatRangeSlider' if is_float else 'IntRangeSlider',
            step=0.01 if is_float else 1)

        image_output = widgets.interactive_output(
            f=self._plot_slice,
            controls={'im_slice': slice_slider,
                      'with_mask': toggle_mask_button,
                      'px_range': range_slider})
        hist_output = widgets.interactive_output(
            f=self._plot_hist,
            controls={'px_range': range_slider})
        table_output = widgets.interactive_output(
            f=self._image_summary,
            controls={'px_range': range_slider})

        # Fix plot areas to the figure size to suppress flickering; figsize is (width, height).
        width, height = self.figsize
        for output in (image_output, hist_output):
            output.layout.width = f'{width / 1.2}in'
            output.layout.height = f'{height / 1.2}in'

        if isinstance(self.y, torch.Tensor):
            slice_controls = widgets.HBox([slice_slider, toggle_mask_button])
        else:
            slice_controls = slice_slider

        image_box = widgets.VBox([image_output, slice_controls], layout=self.vbox_layout)
        hist_box = widgets.VBox([hist_output, range_slider], layout=self.vbox_layout)
        table_box = widgets.VBox([table_output], layout=self.vbox_layout)
        self.box = widgets.HBox(
            [image_box, hist_box, table_box],
            layout=widgets.Layout(
                border='solid 1px lightgrey',
                margin='10px 5px 0px 0px',
                padding='5px',
                width=f'{width * 2 + 3}in'))
class ListViewer(object):
    """ Display multiple images with their masks or labels/predictions in a grid.

    Arguments:
        x (tuple, list): rank 3 tensors to view.
        y (tuple, list): one entry per item in `x`: a mask tensor (segmentation
            task) or a class label as str.
        prediction (tuple, list): class predictions as str, one per item in `x`.
        description: stored on the instance; not displayed by this class.
        figsize: size of each image.
        cmap: colormap for display of `x`.
        max_n: maximum number of items to display.
    """

    def __init__(self, x: 'list | tuple', y=None, prediction=None, description: str = None,
                 figsize=(4, 4), cmap: str = 'bone', max_n=9):
        self.slice_range = (1, len(x))
        x = x[0:max_n]
        if y is not None:
            y = y[0:max_n]
        self.x = x
        self.y = y
        self.prediction = prediction
        self.description = description
        self.figsize = figsize
        self.cmap = cmap
        self.max_n = max_n

    def _generate_views(self):
        """Build one image box per item and arrange them in a near-square grid in `self.box`."""
        n_images = len(self.x)
        boxes = []
        for i, image in enumerate(self.x):
            target = self.y[i] if isinstance(self.y, (list, tuple)) else None
            pred = self.prediction[i] if self.prediction else None
            viewer = BasicViewer(
                x=image,
                y=target,
                prediction=pred,
                figsize=self.figsize,
                cmap=self.cmap)
            # `image_box` is a method; it must be called to obtain the widget.
            boxes.append(viewer.image_box())
        n_cols = int(np.ceil(np.sqrt(n_images))) if n_images else 1
        rows = [widgets.HBox(boxes[start:start + n_cols])
                for start in range(0, n_images, n_cols)]
        self.box = widgets.VBox(children=rows)

    def show(self):
        "Build the grid and display it in the notebook."
        self._generate_views()
        plt.style.use('default')
        display(self.box)