File size: 2,997 Bytes
b2ffc9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import numpy as np

from scipy.ndimage.filters import gaussian_filter, median_filter
from atoms_detection.fast_filters import median_filter_parallel
from PIL import Image


def preprocess_jpg(np_img: np.ndarray) -> np.ndarray:
    """Keep only channel 0 of a ``(rows, cols, channels)`` array.

    The result is a view into *np_img* (basic indexing), not a copy.
    NOTE(review): presumably used because the JPEG channels are identical
    grey levels -- confirm against the data.
    """
    channel_zero = (slice(None), slice(None), 0)
    return np_img[channel_zero]


def dl_prepro_image(np_img: np.ndarray, ruler_units=None, clip=1):
    """Prepare an image for the deep-learning atom detector.

    Pipeline: keep channel 0 of a 3-D array, optionally rescale using the
    scale bar ("ruler") drawn in the image, clip intensities above the
    99.9th percentile, subtract a median-filtered background (window 40),
    clamp negatives to zero and min-max normalise to ``[0, 1]``.

    Args:
        np_img: 2-D image, or 3-D array whose channel 0 is used.
        ruler_units: physical length of the scale bar (presumably nm, given
            how ``rescale_img_to_target_pxls_nm`` uses it). When given,
            rescaling is attempted; if locating the ruler or rescaling
            fails, the image is processed at its original size.
        clip: currently IGNORED -- the clipping quantile is forced to 0.999
            below (hence the "MANUAL CLIP USAGE" warning). Kept only so
            existing call sites keep working.

    Returns:
        The normalised float image, or ``(image, scale_factor)`` if
        rescaling succeeded. An image that is flat after background removal
        is returned as all zeros rather than NaNs.
    """
    if np_img.ndim == 3:
        np_img = preprocess_jpg(np_img)

    scale_factor = None
    if ruler_units is not None:
        try:
            ruler_size = get_ruler_size(np_img)
            np_img, scale_factor = rescale_img_to_target_pxls_nm(
                np_img, ruler_size, ruler_units
            )
        except Exception as e:
            # Best effort: a missing/unrecognised ruler must not abort
            # preprocessing, so continue with the unscaled image.
            print(f"Ruler-based rescaling skipped: {e!r}")

    # NOTE(review): this deliberately overrides the ``clip`` argument.
    print("WARNING, MANUAL CLIP USAGE")
    clip = 0.999

    # Work in floating point so the background subtraction can go negative
    # (and be clamped below) instead of depending on NumPy's integer/float
    # promotion rules or wrapping around for unsigned inputs.
    if not np.issubdtype(np_img.dtype, np.floating):
        np_img = np_img.astype(np.float64)

    max_allowed = np.quantile(np_img, q=clip)
    np_img = np.clip(np_img, a_min=0, a_max=max_allowed)

    try:
        np_bg = median_filter_parallel(np_img, 40, splits=4)
    except Exception as e:
        print(e)
        print("Median filter failed, using slower scipy version")
        np_bg = median_filter(np_img, 40)

    np_clean = np_img - np_bg
    np_clean[np_clean < 0] = 0

    lo = np_clean.min()
    span = np_clean.max() - lo
    if span > 0:
        np_normed = (np_clean - lo) / span
    else:
        # Flat image: min-max normalisation would be 0/0 -> NaN.
        np_normed = np.zeros_like(np_clean)

    if scale_factor is not None:
        return np_normed, scale_factor
    return np_normed


def cv_prepro_image(img: np.ndarray) -> np.ndarray:
    """Remove a smooth background from *img* and min-max normalise it.

    The background is a Gaussian blur with ``sigma=10``; it is subtracted
    from the image and the result is rescaled to ``[0, 1]``.

    Args:
        img: image array. Non-floating inputs are converted to float64
            first, because ``gaussian_filter`` returns the input dtype and
            an unsigned-integer subtraction would wrap around.

    Returns:
        Normalised image. A flat image (no contrast left after background
        removal) yields all zeros instead of NaNs.
    """
    img = np.asarray(img)
    if not np.issubdtype(img.dtype, np.floating):
        img = img.astype(np.float64)

    bg_img = gaussian_filter(img, sigma=10)
    clean_img = img - bg_img

    lo = clean_img.min()
    span = clean_img.max() - lo
    if span == 0:
        # Avoid 0/0 -> NaN for a constant image.
        return np.zeros_like(clean_img)
    return (clean_img - lo) / span


def get_ruler_size(img: np.ndarray) -> int:
    """Measure the length, in pixels, of the scale bar ("ruler") in *img*.

    The ruler is expected to start at a fixed, empirically found position
    near the bottom-left corner: 6.25% of the width in from the left and
    roughly 6.25% of the height up from the bottom. That start pixel must
    equal the image maximum. The length is the number of consecutive pixels
    in that row, going right, that share the start pixel's value.

    Args:
        img: 2-D image array.

    Returns:
        Ruler length in pixels (at least 1). If the run reaches the right
        image edge, the length up to the edge is returned.

    Raises:
        ValueError: if the pixel at the expected start position is not the
            image maximum, i.e. no ruler was found there.
    """
    start_fraction = 0.0625  # empirically located here in samples
    n_rows, n_cols = img.shape[0], img.shape[1]
    row = int(n_rows * (1 - start_fraction) - 1)
    start_col = int(n_cols * start_fraction)

    ruler_value = img[row, start_col]
    if ruler_value != img.max():
        print("Ruler start position verification failed, skipping rescaling")
        raise ValueError(
            f"pixel ({row}, {start_col}) is not the image maximum; "
            "no ruler found at the expected position"
        )

    col = start_col
    # Bounded scan: stop at the image edge instead of raising IndexError.
    while col < n_cols and img[row, col] == ruler_value:
        col += 1
    return col - start_col


def rescale_img_to_target_pxls_nm(
    img: np.ndarray, ruler_pixels: int, ruler_units: int, atom_prior=None
):
    """Resample *img* to the pixel density the models were trained on.

    The original training images were 512x512 pixels spanning 15 nm
    (about 34 pixels per nanometre). The current density is
    ``ruler_pixels / ruler_units``, and the image is resized by the ratio
    of the two densities.

    Args:
        img: 2-D image array accepted by ``PIL.Image.fromarray``.
        ruler_pixels: measured length of the scale bar in pixels.
        ruler_units: physical length of the scale bar (presumably nm).
        atom_prior: not supported yet; must be ``None``.

    Returns:
        ``(resized_image, scaling_factor)``.

    Raises:
        ValueError: if ``ruler_pixels`` or ``ruler_units`` is not positive.
        NotImplementedError: if ``atom_prior`` is given.
    """
    if ruler_pixels <= 0 or ruler_units <= 0:
        raise ValueError(
            f"ruler_pixels ({ruler_pixels}) and ruler_units ({ruler_units}) "
            "must both be positive"
        )

    # original images were 512x512 and labelled 15nm across, 34 pixels per nano
    target_scale = 512 / 15
    pixels_per_nanometer = ruler_pixels / ruler_units  # current pixels per nano
    scaling_factor = target_scale / pixels_per_nanometer

    # PIL expects (width, height) == (columns, rows): the reverse of the
    # numpy shape order. Passing (rows, cols) distorts non-square images.
    new_size = (
        int(img.shape[1] * scaling_factor),
        int(img.shape[0] * scaling_factor),
    )

    if atom_prior is not None:
        raise NotImplementedError
    return np.array(Image.fromarray(img).resize(new_size)), scaling_factor