# Spaces:
# Runtime error
# Runtime error
import numpy as np | |
from scipy import signal | |
from scipy import ndimage | |
from scipy.fftpack import next_fast_len | |
from skimage.transform import rotate | |
from skimage._shared.utils import convert_to_float | |
from skimage.transform import warp | |
import matplotlib.pyplot as plt | |
import cv2 | |
from copy import deepcopy | |
def get_directional_std(image, theta=None, *, preserve_range=False):
    """Return, for each rotation angle, the mean column-wise std of the image.

    The image is centre-cropped to a square and zeroed outside the inscribed
    circle. For every angle in ``theta`` it is rotated with
    ``skimage.transform.warp``. The standard deviation along axis 0 of the
    central square (the one inscribed in that circle) is then averaged over
    columns.

    Parameters
    ----------
    image : 2-D array
        Input image; non-square images are centre-cropped to a square.
    theta : sequence of float, optional
        Rotation angles in degrees. Defaults to ``np.arange(180)``.
    preserve_range : bool, optional
        Forwarded to ``convert_to_float``.

    Returns
    -------
    ndarray of shape (len(theta),)
        One value per angle.

    Raises
    ------
    ValueError
        If ``image`` is not 2-D.
    """
    if image.ndim != 2:
        raise ValueError('The input image must be 2-D')
    if theta is None:
        theta = np.arange(180)
    # Copy first: the masking below writes into the array.
    image = convert_to_float(image.copy(), preserve_range)  # TODO: needed?

    # Centre-crop to a square.
    shape_min = min(image.shape)
    img_shape = np.array(image.shape)
    slices = tuple(
        slice(int(np.ceil(excess / 2)), int(np.ceil(excess / 2) + shape_min))
        if excess > 0 else slice(None)
        for excess in (img_shape - shape_min)
    )
    image = image[slices]

    # From here on the image is square, so a single centre/radius applies.
    # Explicit row/column grids avoid building a ragged object array from
    # ogrid (which mis-broadcasts for a 1x1 image).
    size = image.shape[0]
    radius = size // 2
    center = size // 2
    rows, cols = np.ogrid[:size, :size]
    dist = (rows - center) ** 2 + (cols - center) ** 2
    image[dist > radius ** 2] = 0

    # Axis-aligned square inscribed in the circle of the given radius.
    lo = int(np.ceil(radius * (1 - 1 / np.sqrt(2))))
    hi = int(np.ceil(radius * (1 + 1 / np.sqrt(2))))
    valid_square_slice = slice(lo, hi)

    result = np.zeros(len(theta))
    for i, angle in enumerate(np.deg2rad(theta)):
        cos_a, sin_a = np.cos(angle), np.sin(angle)
        # Homogeneous rotation about (center, center).
        R = np.array([[cos_a, sin_a, -center * (cos_a + sin_a - 1)],
                      [-sin_a, cos_a, -center * (cos_a - sin_a - 1)],
                      [0, 0, 1]])
        rotated = warp(image, R, clip=False)
        result[i] = rotated[valid_square_slice, valid_square_slice].std(axis=0).mean()
    return result
def acf2d(x, nlags=None):
    """Autocorrelation along axis 0, computed independently for every column.

    Uses an FFT and the bias-adjusted autocovariance estimator (lag ``k`` is
    divided by the number of overlapping pairs, ``n - k``). The result is then
    normalised by the lag-0 value, so row 0 of the result is 1.

    Parameters
    ----------
    x : array_like
        Samples along axis 0. Trailing axes are treated as independent series;
        a 1-D input is a single series.
    nlags : int, optional
        Largest lag returned. Defaults to ``len(x) - 1``.

    Returns
    -------
    ndarray of shape ``(nlags + 1,) + x.shape[1:]``
    """
    x = np.asarray(x)
    xo = x - x.mean(axis=0)
    nobs = len(xo)
    if nlags is None:
        nlags = nobs - 1
    # Zero-pad to at least 2 * nobs + 1 so the FFT-based (circular)
    # correlation has no wrap-around.
    nfft = next_fast_len(2 * nobs + 1)
    spectrum = np.fft.fft(xo, n=nfft, axis=0)
    acov = np.fft.ifft(spectrum * np.conjugate(spectrum), axis=0)[:nobs].real
    # Pairs per lag: nobs, nobs-1, ..., 1. Shaped to broadcast over any
    # trailing axes (a fixed (n, 1) shape broke 1-D input).
    pairs = np.arange(nobs, 0, -1).reshape((nobs,) + (1,) * (xo.ndim - 1))
    acov = acov / pairs
    return acov[:nlags + 1] / acov[:1]
def get_period(acf_table, n_samples=50):
    """Estimate a period from the first peak of randomly chosen ACF columns.

    Up to ``n_samples`` distinct columns of ``acf_table`` are drawn at random.
    For each, the index of its first local maximum
    (``scipy.signal.find_peaks``) is taken as a period candidate.

    Parameters
    ----------
    acf_table : 2-D array
        Autocorrelation values, lags along axis 0 and one series per column.
    n_samples : int, optional
        Maximum number of columns to inspect.

    Returns
    -------
    period
        Median of the candidates; the single candidate itself if only one
        was found; ``np.nan`` if none were found.
    spread
        Standard deviation of the candidates lying within the inter-quartile
        range, or ``np.nan`` when fewer than two candidates exist.
    """
    # TODO: use peak heights to select best candidates. use std to eliminate outliers
    n_cols = acf_table.shape[1]
    # Draw distinct columns: sampling with replacement could count the same
    # column several times and skew the median.
    columns = np.random.choice(n_cols, min(n_cols, n_samples), replace=False)
    candidates = []
    for col in columns:
        peaks = signal.find_peaks(acf_table[:, col])[0]
        if len(peaks) > 0:
            candidates.append(peaks[0])
    candidates = np.array(candidates)

    if len(candidates) == 0:
        return np.nan, np.nan
    if len(candidates) == 1:
        return candidates[0], np.nan
    q1, q3 = np.quantile(candidates, [0.25, 0.75])
    inner = candidates[(candidates >= q1) & (candidates <= q3)]
    return np.median(candidates), np.std(inner)
def get_rotation_with_confidence(padded_image, blur_size=4, make_plots=True):
    """Find the angle that minimises ``get_directional_std`` on a blurred image.

    The image is box-blurred with a ``blur_size`` x ``blur_size`` kernel and
    passed to ``get_directional_std`` with its default angles.

    Returns
    -------
    rotation_angle
        Index of the smallest value in the directional-std profile.
    rotation_quality
        ``1 - min(profile) / median(profile)``.
    """
    smoothed = cv2.blur(padded_image, (blur_size, blur_size))
    profile = get_directional_std(smoothed)
    best_angle = np.argmin(profile)
    best_value = np.min(profile)
    quality = 1 - best_value / np.median(profile)

    if make_plots:
        plt.plot(profile)
        plt.axvline(best_angle, c='k')
        plt.title(f'quality: {quality:0.2f}')

    return best_angle, quality
def calculate_autocorrelation(oriented_img, blur_kernel=(7, 1), make_plots=True):
    """Blur the transposed image and return its column-wise autocorrelation.

    The transpose of ``oriented_img`` is box-blurred with ``blur_kernel`` and
    passed to ``acf2d``.

    When ``make_plots`` is set, the autocorrelation map and its row sums are
    plotted side by side.
    """
    smoothed = cv2.blur(oriented_img.T, blur_kernel)
    acf = acf2d(smoothed)

    if make_plots:
        _, (ax_map, ax_profile) = plt.subplots(ncols=2, figsize=(12, 6))
        ax_map.imshow(acf)
        ax_profile.plot(acf.sum(axis=1))

    return acf
def get_period_with_confidence(autocorrelation_tab, n_samples=30):
    """Estimate the period via ``get_period`` and attach a confidence score.

    The confidence is ``period / (period + 2 * spread)``. When ``get_period``
    reports no spread (NaN, e.g. fewer than two candidate peaks), a fixed low
    confidence of 0.001 is used instead.

    Returns
    -------
    (period, period_confidence)
    """
    period, period_std = get_period(autocorrelation_tab, n_samples=n_samples)
    # NaN never compares equal to anything (``x == np.nan`` is always False),
    # so the missing-spread case must be detected with np.isnan.
    if np.isnan(period_std):
        period_confidence = 0.001
    else:
        period_confidence = period / (period + 2 * period_std)
    return period, period_confidence
def calculate_white_fraction(img, blur_size=4, make_plots=True):  # TODO: add mask
    """Fraction of "bright" columns over a whole number of bright/dark periods.

    The image is box-blurred and summed over axis 0, giving one value per
    column. A column is labelled bright when its sum exceeds the midpoint of
    the 15th and 85th percentiles. The label profile is trimmed to an even
    number of complete runs (so partial runs at the borders are ignored), and
    the mean label over that span is returned.

    Returns
    -------
    float
        Fraction of bright columns in the trimmed span, or ``np.nan`` when
        there are too few label changes to form a complete period.
    """
    blurred = cv2.blur(img, (blur_size, blur_size))
    blurred_sum = blurred.sum(axis=0)
    lower, upper = np.quantile(blurred_sum, [0.15, 0.85])
    sign = blurred_sum > (lower + upper) / 2
    # Index i is reported when sign[i] != sign[i + 1], i.e. i is the last
    # element of a run.
    sign_change_indices = np.where(sign[:-1] != sign[1:])[0]
    ends_match = sign[-1] == sign[0]

    if len(sign_change_indices) >= 2 + ends_match:
        # Start at the first element of the first complete run...
        cut_first = sign_change_indices[0] + 1
        # ...and end at the last element of a run chosen so the span holds an
        # even number of runs. ``+ 1`` makes the slice include that element.
        last_run_end = sign_change_indices[-2] if ends_match else sign_change_indices[-1]
        cut_last = last_run_end + 1
        white_fraction = np.mean(sign[cut_first:cut_last])
    else:
        white_fraction = np.nan
        cut_first, cut_last = None, None

    if make_plots:
        fig, axs = plt.subplots(ncols=3, figsize=(16, 6))
        # Normalise in float: an in-place divide on an integer column-sum
        # array would raise.
        profile = blurred_sum.astype(float) - blurred_sum.min()
        peak = profile.max()
        if peak > 0:
            profile /= peak
        axs[0].plot(profile)
        axs[0].plot(sign)
        axs[1].plot(profile[cut_first:cut_last])
        axs[1].plot(sign[cut_first:cut_last])
        axs[2].imshow(img, cmap='gray')
        for i, idx in enumerate(sign_change_indices):
            axs[2].axvline(idx, c=['r', 'lime'][i % 2])
        fig.suptitle(f'fraction: {white_fraction:0.2f}')

    return white_fraction
def process_img_crop(img, nm_per_px=1, make_plots=False, return_extra=False):
    """Measure direction, period and bright-band width on one square crop.

    Steps:
    1. Estimate the dominant angle (``get_rotation_with_confidence``).
    2. Rotate by the negated angle and keep a central square.
    3. Estimate the period from the autocorrelation (``calculate_autocorrelation``
       and ``get_period_with_confidence``).
    4. Estimate the bright fraction (``calculate_white_fraction``).

    Parameters
    ----------
    img : 2-D array
        Square image crop.
    nm_per_px : float, optional
        Scale applied to the period and lumen width.
    make_plots : bool, optional
        Forwarded to the helpers; also prints the period.
    return_extra : bool, optional
        Also return the autocorrelation and oriented image under ``'extra'``.

    Returns
    -------
    dict
        Keys ``'direction'``, ``'direction confidence'``, ``'period'``,
        ``'period confidence'``, ``'lumen width'`` (plus ``'extra'``).

    Raises
    ------
    ValueError
        If ``img`` is not square.
    """
    if img.shape[0] != img.shape[1]:
        raise ValueError('img must be square')
    crop_size = img.shape[0]

    # Find orientation.
    rotation_angle, rotation_quality = get_rotation_with_confidence(img, blur_size=4, make_plots=make_plots)

    # Rotate, then keep a central region. The margin leaves a square of side
    # ~crop_size / sqrt(2), i.e. the square inscribed in the crop's circle.
    crop_margin = int((1 - 1 / np.sqrt(2)) * crop_size * 0.5)
    rotated = rotate(img, -rotation_angle)
    # Explicit end indices: ``-crop_margin`` would be ``-0`` (an empty slice)
    # whenever crop_margin rounds down to 0 for small crops.
    # NOTE(review): the top margin is doubled while the other sides are not;
    # kept as-is, confirm whether that asymmetry is intended.
    oriented_img = rotated[2 * crop_margin:crop_size - crop_margin,
                           crop_margin:crop_size - crop_margin]

    # Calculate autocorrelation and find the period.
    autocorrelation = calculate_autocorrelation(oriented_img, blur_kernel=(7, 1), make_plots=make_plots)
    period, period_confidence = get_period_with_confidence(autocorrelation)
    if make_plots:
        print(f'period: {period}, confidence: {period_confidence}')

    # Find white fraction.
    white_fraction = calculate_white_fraction(oriented_img, make_plots=make_plots)
    white_width = white_fraction * period

    result = {
        'direction': rotation_angle,
        'direction confidence': rotation_quality,
        'period': period * nm_per_px,
        'period confidence': period_confidence,
        'lumen width': white_width * nm_per_px
    }
    if return_extra:
        result['extra'] = {
            'autocorrelation': autocorrelation,
            'oriented_img': oriented_img
        }
    return result
def get_top_k(a, k):
    """Return the ``k`` largest values of the 1-D array ``a`` (in no particular order).

    ``k`` is clamped to ``[0, len(a)]``: ``k <= 0`` gives an empty array and
    ``k > len(a)`` gives every element.
    """
    a = np.asarray(a)
    k = max(0, min(int(k), a.shape[0]))
    if k == 0:
        # argpartition(a, -0)[-0:] would select *every* element.
        return a[:0]
    ind = np.argpartition(a, -k)[-k:]
    return a[ind]
def get_crops(img, distance_map, crop_size, N_sample):
    """Yield up to ``N_sample`` random ``crop_size`` x ``crop_size`` crops of ``img``.

    Crop centres are drawn without replacement from the pixels whose
    ``distance_map`` value is at least ``crop_r = sqrt(2) * crop_size / 2`` and
    that lie more than ``crop_r`` pixels away from every image border.

    Yields
    ------
    ndarray
        A copy of the selected window of ``img``.
    """
    crop_r = np.sqrt(2) * crop_size / 2
    ys, xs = np.where(distance_map >= crop_r)
    inside = ((ys > crop_r) & (xs > crop_r)
              & (ys < distance_map.shape[0] - crop_r)
              & (xs < distance_map.shape[1] - crop_r))
    ys, xs = ys[inside], xs[inside]
    n_available = len(xs)
    chosen = np.random.choice(np.arange(n_available), min(N_sample, n_available), replace=False)
    half = crop_size // 2
    for idx in chosen:
        top = ys[idx] - half
        left = xs[idx] - half
        # Extent is crop_size (not 2 * half) so odd sizes are honoured.
        yield img[top:top + crop_size, left:left + crop_size].copy()
def sliced_mean(x, slice_size):
    """Mean of every ``slice_size`` x ``slice_size`` window of the 2-D array ``x``.

    Entry ``[i, j]`` of the result is the mean of
    ``x[i:i + slice_size, j:j + slice_size]``. The output shape is
    ``(H - slice_size + 1, W - slice_size + 1)``. Computed with cumulative sums
    along axis 0 and then axis 1.
    """
    cs_y = np.cumsum(x, axis=0)
    # A leading zero row lets each window sum be a single difference.
    cs_y = np.concatenate((np.zeros((1, cs_y.shape[1]), dtype=cs_y.dtype), cs_y), axis=0)
    slices_y = (cs_y[slice_size:] - cs_y[:-slice_size]) / slice_size
    cs_xy = np.cumsum(slices_y, axis=1)
    cs_xy = np.concatenate((np.zeros((cs_xy.shape[0], 1), dtype=cs_xy.dtype), cs_xy), axis=1)
    return (cs_xy[:, slice_size:] - cs_xy[:, :-slice_size]) / slice_size


def sliced_var(x, slice_size):
    """Population variance of every ``slice_size`` x ``slice_size`` window of ``x``.

    Computed as ``E[x^2] - E[x]^2`` from two ``sliced_mean`` passes, in float64.
    Same output shape as ``sliced_mean``.
    """
    x = x.astype('float64')
    var = sliced_mean(x ** 2, slice_size) - sliced_mean(x, slice_size) ** 2
    # The E[x^2] - E[x]^2 form can dip slightly below zero through rounding;
    # clamp so powers/probabilities derived from it stay valid.
    return np.maximum(var, 0)
def select_samples(granum_image, granum_mask, crop_size=96, n_samples=64, granum_fraction_min=1.0, variance_p=2):
    """Randomly pick square crops whose window is covered enough by the mask.

    Candidates are the ``crop_size`` x ``crop_size`` windows whose mean
    ``granum_mask`` value is at least ``granum_fraction_min``. They are drawn
    without replacement, with probability proportional to the window variance
    of ``granum_image`` raised to ``variance_p``. Weights are uniform when
    ``variance_p == 0`` or when every weight is zero.

    Returns
    -------
    ndarray
        Stack of selected crops. It is empty, with shape
        ``(0, crop_size, crop_size)``, when there are no candidates.
    """
    granum_occupancy = sliced_mean(granum_mask, crop_size)
    # Top-left (row, col) corner of every admissible window.
    possible_indices = np.stack(np.where(granum_occupancy >= granum_fraction_min), axis=1)
    n_candidates = len(possible_indices)
    if n_candidates == 0:
        # Nothing to sample from.
        return np.empty((0, crop_size, crop_size), dtype=granum_image.dtype)

    if variance_p == 0:
        p = np.ones(n_candidates)
    else:
        variance_map = sliced_var(granum_image, crop_size)
        p = variance_map[possible_indices[:, 0], possible_indices[:, 1]] ** variance_p
    total = np.sum(p)
    if total > 0:
        p = p / total
    else:
        # All weights are zero (e.g. perfectly flat windows): fall back to
        # uniform sampling instead of dividing by zero.
        p = np.full(n_candidates, 1 / n_candidates)

    # Without replacement, at most as many draws as there are non-zero weights.
    n_pick = min(n_candidates, n_samples, int(np.count_nonzero(p)))
    chosen_indices = np.random.choice(np.arange(n_candidates), n_pick, replace=False, p=p)
    crops = [
        granum_image[row:row + crop_size, col:col + crop_size]
        for row, col in possible_indices[chosen_indices]
    ]
    return np.array(crops)
def calculate_distance_map(mask):
    """Euclidean distance from each foreground pixel of ``mask`` to the nearest background.

    The image border is treated as background: a one-pixel frame of False is
    added before ``scipy.ndimage.distance_transform_edt`` and removed afterwards.
    """
    framed = np.pad(mask, 1, mode='constant', constant_values=False)
    distances = ndimage.distance_transform_edt(framed)
    interior = (slice(1, -1), slice(1, -1))
    return distances[interior]
def measure_object(
        img, mask,
        nm_per_px=1, n_tries=3,
        direction_thr_min=0.07, direction_thr_enough=0.1,
        crop_size=200,
        **kwargs):
    """Measure stripe statistics of one masked object and optionally orient it.

    Up to ``n_tries`` random crops lying well inside ``mask`` are analysed with
    ``process_img_crop``. The result with the highest ``'direction confidence'``
    is kept, stopping early once that confidence exceeds
    ``direction_thr_enough``. If the best confidence reaches
    ``direction_thr_min``:
    - ``mask`` and ``img`` are rotated by ``90 - direction`` degrees (with
      resizing);
    - both are cropped to the rotated mask's bounding box;
    - the crops are stored under ``'mask_oriented'`` and ``'img_oriented'``.

    ``**kwargs`` is accepted and ignored.

    Returns
    -------
    dict
        The best ``process_img_crop`` result, possibly extended with the
        oriented mask/image; empty if no crop could be analysed.
    """
    if n_tries < 1:
        return {}
    distance_map = calculate_distance_map(mask)
    # Cap the crop size by the n_tries-th largest distance divided by sqrt(2).
    crop_size = min(crop_size, int(min(get_top_k(distance_map.flatten(), n_tries) * 0.5 ** 0.5)))

    direction_confidence = 0
    best_stripes_data = {}
    for img_crop in get_crops(img, distance_map, crop_size, n_tries):
        stripes_data = process_img_crop(img_crop, nm_per_px=nm_per_px)
        if stripes_data['direction confidence'] >= direction_confidence:
            best_stripes_data = deepcopy(stripes_data)
            direction_confidence = stripes_data['direction confidence']
        if direction_confidence > direction_thr_enough:
            break

    result = best_stripes_data
    if result and direction_confidence >= direction_thr_min:
        angle = 90 - result['direction']
        mask_oriented = rotate(mask, angle, resize=True).astype('bool')
        cols = np.where(np.any(mask_oriented, axis=0))[0]
        rows = np.where(np.any(mask_oriented, axis=1))[0]
        # Slice ends are exclusive: +1 keeps the last row/column of the object.
        y0, y1 = rows[0], rows[-1] + 1
        x0, x1 = cols[0], cols[-1] + 1
        result['mask_oriented'] = mask_oriented[y0:y1, x0:x1]
        result['img_oriented'] = rotate(img, angle, resize=True)[y0:y1, x0:x1]
    # TODO: granum shape measurements (measure_granum_shape) and the layer
    # count (height / period) were disabled here; re-enable once available.
    return result
# def measure_object( | |
# img, mask, | |
# nm_per_px=1, n_tries = 3, | |
# direction_thr_min = 0.07, direction_thr_enough = 0.1, | |
# crop_size = 200, | |
# **kwargs): | |
# distance_map = calculate_distance_map(mask) | |
# crop_size = min(crop_size, int((min(get_top_k(distance_map.flatten(), n_tries)*0.5)**0.5))) | |
# direction_confidence = 0 | |
# best_stripes_data = {} | |
# for i, img_crop in enumerate(select_samples(img, mask, crop_size=crop_size, n_samples=n_tries)): | |
# stripes_data = process_img_crop(img_crop, nm_per_px=nm_per_px) | |
# if stripes_data['direction_confidence'] >= direction_confidence: | |
# best_stripes_data = deepcopy(stripes_data) | |
# direction_confidence = stripes_data['direction_confidence'] | |
# if direction_confidence > direction_thr_enough: | |
# break | |
# result = best_stripes_data | |
# if direction_confidence >= direction_thr_min: | |
# mask_oriented = rotate(mask, 90-result['direction'], resize=True).astype('bool') | |
# idx_begin_x, idx_end_x = np.where(np.any(mask_oriented, axis=0))[0][np.array([0, -1])] | |
# idx_begin_y, idx_end_y = np.where(np.any(mask_oriented, axis=1))[0][np.array([0, -1])] | |
# result['mask_oriented'] = mask_oriented[idx_begin_y:idx_end_y, idx_begin_x:idx_end_x] | |
# result['img_oriented'] = rotate(img, 90-result['direction'], resize=True)[idx_begin_y:idx_end_y, idx_begin_x:idx_end_x] | |
# # measurements = measure_granum_shape(result['mask_oriented'], nm_per_px=nm_per_px, oriented=True) | |
# # else: | |
# # measurements = measure_granum_shape(mask, nm_per_px=nm_per_px, oriented=False) | |
# # result.update(**measurements) | |
# # N_layers = result['height'] / result['period'] | |
# # if np.isfinite(N_layers): | |
# # N_layers = round(N_layers) | |
# return result #{**measurements, **best_stripes_data, 'N layers': N_layers} |